whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -0,0 +1,3838 @@
1
+ #include "parakeet.h"
2
+ #include "parakeet-arch.h"
3
+
4
+ #include "ggml.h"
5
+ #include "ggml-cpp.h"
6
+ #include "ggml-alloc.h"
7
+ #include "ggml-backend.h"
8
+
9
+ #include <atomic>
10
+ #include <algorithm>
11
+ #include <cassert>
12
+ #include <cfloat>
13
+ #define _USE_MATH_DEFINES
14
+ #include <cmath>
15
+ #include <climits>
16
+ #include <cstdarg>
17
+ #include <cstdio>
18
+ #include <cstring>
19
+ #include <fstream>
20
+ #include <functional>
21
+ #include <cctype>
22
+ #include <map>
23
+ #include <random>
24
+ #include <set>
25
+ #include <string>
26
+ #include <thread>
27
+ #include <vector>
28
+
29
+ #ifdef _MSC_VER
30
+ #include <codecvt>
31
+ #endif
32
+
33
+ #if defined(PARAKEET_BIG_ENDIAN)
34
// Returns a copy of `value` whose bytes are in reversed order.
// Only meaningful for trivially copyable scalar types.
template<typename T>
static T byteswap(T value) {
    T reversed;

    const char * first = reinterpret_cast<const char *>(&value);
    char *       out   = reinterpret_cast<char *>(&reversed);

    // copy the bytes of `value` back-to-front into `reversed`
    std::reverse_copy(first, first + sizeof(T), out);

    return reversed;
}
45
+
46
+ template<typename T>
47
+ static void byteswap_tensor_data(ggml_tensor * tensor) {
48
+ T * datum = reinterpret_cast<T *>(tensor->data);
49
+ for (int i = 0; i < ggml_nelements(tensor); i++) {
50
+ datum[i] = byteswap(datum[i]);
51
+ }
52
+ }
53
+
54
+ static void byteswap_tensor(ggml_tensor * tensor) {
55
+ switch (tensor->type) {
56
+ case GGML_TYPE_I16: {
57
+ byteswap_tensor_data<int16_t>(tensor);
58
+ break;
59
+ }
60
+ case GGML_TYPE_F16: {
61
+ byteswap_tensor_data<ggml_fp16_t>(tensor);
62
+ break;
63
+ }
64
+ case GGML_TYPE_I32: {
65
+ byteswap_tensor_data<int32_t>(tensor);
66
+ break;
67
+ }
68
+ case GGML_TYPE_F32: {
69
+ byteswap_tensor_data<float>(tensor);
70
+ break;
71
+ }
72
+ default: { // GML_TYPE_I8
73
+ break;
74
+ }
75
+ }
76
+ }
77
+
78
+ #define BYTESWAP_VALUE(d) d = byteswap(d)
79
+ #define BYTESWAP_FILTERS(f) \
80
+ do { \
81
+ for (auto & datum : f.data) { \
82
+ datum = byteswap(datum); \
83
+ } \
84
+ } while (0)
85
+ #define BYTESWAP_TENSOR(t) \
86
+ do { \
87
+ byteswap_tensor(t); \
88
+ } while (0)
89
+ #else
90
+ #define BYTESWAP_VALUE(d) do {} while (0)
91
+ #define BYTESWAP_FILTERS(f) do {} while (0)
92
+ #define BYTESWAP_TENSOR(t) do {} while (0)
93
+ #endif
94
+
95
+ #ifdef __GNUC__
96
+ #ifdef __MINGW32__
97
+ #define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
98
+ #else
99
+ #define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
100
+ #endif
101
+ #else
102
+ #define PARAKEET_ATTRIBUTE_FORMAT(...)
103
+ #endif
104
+
105
+ //
106
+ // logging
107
+ //
108
+
109
+ PARAKEET_ATTRIBUTE_FORMAT(2, 3)
110
+ static void parakeet_log_internal (ggml_log_level level, const char * format, ...);
111
+ static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data);
112
+
113
+ #define PARAKEET_LOG_ERROR(...) parakeet_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
114
+ #define PARAKEET_LOG_WARN(...) parakeet_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
115
+ #define PARAKEET_LOG_INFO(...) parakeet_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
116
+
117
+ // define this to enable verbose trace logging - useful for debugging purposes
118
+ //#define PARAKEET_DEBUG
119
+
120
+ #if defined(PARAKEET_DEBUG)
121
+ #define PARAKEET_LOG_DEBUG(...) parakeet_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
122
+ #else
123
+ #define PARAKEET_LOG_DEBUG(...)
124
+ #endif
125
+
126
+ #define PARAKEET_ASSERT(x) \
127
+ do { \
128
+ if (!(x)) { \
129
+ PARAKEET_LOG_ERROR("PARAKEET_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
130
+ abort(); \
131
+ } \
132
+ } while (0)
133
+
134
+ #define PARAKEET_MAX_NODES 8192
135
+
136
+ // Threshold for when local attention should be used.
137
+ // 8192 frames x 80ms = 655 s (about 10.9 mins)
138
+ static constexpr int PARAKEET_LOCAL_ATTN_THRESHOLD = 8192;
139
+ // Window of context in each direction of the current token.
140
+ // 128 frames * 80ms = 10.24 s
141
+ static constexpr int PARAKEET_LOCAL_ATTN_WINDOW = 128;
142
+
143
+ static std::string format(const char * fmt, ...) {
144
+ va_list ap;
145
+ va_list ap2;
146
+ va_start(ap, fmt);
147
+ va_copy(ap2, ap);
148
+ int size = vsnprintf(NULL, 0, fmt, ap);
149
+ GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
150
+ std::vector<char> buf(size + 1);
151
+ int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
152
+ GGML_ASSERT(size2 == size);
153
+ va_end(ap2);
154
+ va_end(ap);
155
+ return std::string(buf.data(), size);
156
+ }
157
+
158
+ //
159
+ // ggml helpers
160
+ //
161
+
162
+ static bool ggml_graph_compute_helper(
163
+ struct ggml_cgraph * graph,
164
+ int n_threads,
165
+ ggml_abort_callback abort_callback,
166
+ void * abort_callback_data) {
167
+ ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
168
+
169
+ auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
170
+
171
+ auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
172
+ if (set_abort_callback_fn) {
173
+ set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data);
174
+ }
175
+
176
+ auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
177
+ if (ggml_backend_set_n_threads_fn) {
178
+ ggml_backend_set_n_threads_fn(backend.get(), n_threads);
179
+ }
180
+
181
+ return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS;
182
+ }
183
+
184
+ static bool ggml_graph_compute_helper(
185
+ ggml_backend_sched_t sched,
186
+ struct ggml_cgraph * graph,
187
+ int n_threads,
188
+ bool sched_reset = true) {
189
+ for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
190
+ ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
191
+ ggml_backend_dev_t dev = ggml_backend_get_device(backend);
192
+ ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
193
+
194
+ auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
195
+ if (fn_set_n_threads) {
196
+ fn_set_n_threads(backend, n_threads);
197
+ }
198
+ }
199
+
200
+ const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS);
201
+
202
+ if (!t || sched_reset) {
203
+ ggml_backend_sched_reset(sched);
204
+ }
205
+
206
+ return t;
207
+ }
208
+
209
+ // TODO: move these functions to ggml-base with support for ggml-backend?
210
+
211
+
212
// Mel spectrogram container: dimensions plus a flat float buffer.
struct parakeet_mel {
    int n_len     = 0;
    int n_len_org = 0;
    int n_mel     = 0;

    // spectrogram values (layout is defined by the code that fills it)
    std::vector<float> data;
};
219
+
220
// Mel filterbank: dimensions plus a flat float buffer of filter coefficients.
struct parakeet_filters {
    int32_t n_mel = 0;
    int32_t n_fb  = 0; // number of frequency bins

    std::vector<float> data;
};
226
+
227
// Vocabulary: two-way mapping between token strings and token ids, plus the
// ids of the special tokens.
struct parakeet_vocab {
    using id    = int32_t;
    using token = std::string;

    int    n_vocab          = 8192;
    size_t max_token_length = 0;

    std::map<token, id> token_to_id;
    std::map<id, token> id_to_token;

    // special token ids
    // NOTE(review): these have no default value here; presumably they are
    // assigned when the vocabulary is loaded - confirm against the loader.
    id token_unk;
    id token_bos;
    id token_blank;
    id token_eos;
};
242
+
243
+ struct parakeet_segment {
244
+ int64_t t0;
245
+ int64_t t1;
246
+
247
+ std::string text;
248
+
249
+ std::vector<parakeet_token_data> tokens;
250
+ };
251
+
252
+ struct parakeet_batch {
253
+ int32_t n_tokens;
254
+
255
+ parakeet_token * token;
256
+ int32_t * i_time; // index of the audio frame
257
+ parakeet_pos * pos;
258
+ int32_t * n_seq_id; // always 1, here for consistency with llama.cpp
259
+ parakeet_seq_id ** seq_id; // null terminated
260
+ int8_t * logits;
261
+ };
262
+
263
+ // ggml_backend_sched wrapper for parakeet usage
264
+ struct parakeet_sched {
265
+ ggml_backend_sched_t sched = nullptr;
266
+
267
+ std::vector<uint8_t> meta;
268
+ };
269
+
270
+ // TODO: Find out whether there are multiple version types. It is not yet clear to me
271
+ // at this point.
272
// Model architecture identifier.
enum parakeet_arch {
    PARAKEET_ARCH_UNKNOWN = 0,
    PARAKEET_ARCH_TDT     = 1, // NVIDIA Parakeet TDT (RNN-T)
};
276
+
277
+ struct parakeet_hparams {
278
+ int32_t n_vocab = 8192;
279
+ int32_t n_audio_ctx = 0; // 0 = unlimited, will be set based on input
280
+ int32_t n_audio_state = 1024;
281
+ int32_t n_audio_head = 8;
282
+ int32_t n_audio_layer = 24;
283
+ int32_t n_mels = 128;
284
+ int32_t ftype = 1;
285
+ int32_t n_fft = 512; // FFT size for mel spectrogram
286
+ float eps = 1e-5f;
287
+ int32_t subsampling_factor = 8;
288
+ int32_t n_subsampling_channels = 256;
289
+ int32_t n_conv_kernel = 9;
290
+ int32_t n_pred_dim = 640;
291
+ int32_t n_pred_layers = 2;
292
+ int32_t n_tdt_durations = 5;
293
+ int32_t n_max_tokens = 10;
294
+
295
+ parakeet_arch arch = PARAKEET_ARCH_TDT;
296
+ };
297
+
298
+ struct parakeet_layer_encoder {
299
+ struct ggml_tensor * norm_ff1_w = nullptr;
300
+ struct ggml_tensor * norm_ff1_b = nullptr;
301
+
302
+ struct ggml_tensor * ff1_linear1_w = nullptr;
303
+ struct ggml_tensor * ff1_linear2_w = nullptr;
304
+
305
+ struct ggml_tensor * norm_conv_w = nullptr;
306
+ struct ggml_tensor * norm_conv_b = nullptr;
307
+
308
+ struct ggml_tensor * conv_pw1_w = nullptr; // pointwise_conv1
309
+ struct ggml_tensor * conv_dw_w = nullptr; // depthwise_conv
310
+ struct ggml_tensor * conv_bn_w = nullptr; // batch_norm weight
311
+ struct ggml_tensor * conv_bn_b = nullptr; // batch_norm bias
312
+ struct ggml_tensor * conv_bn_mean = nullptr; // batch_norm running_mean
313
+ struct ggml_tensor * conv_bn_var = nullptr; // batch_norm running_var
314
+ struct ggml_tensor * conv_bn_num_batches = nullptr; // batch_norm num_batches_tracked
315
+ struct ggml_tensor * conv_pw2_w = nullptr; // pointwise_conv2
316
+
317
+ struct ggml_tensor * norm_attn_w = nullptr;
318
+ struct ggml_tensor * norm_attn_b = nullptr;
319
+
320
+ struct ggml_tensor * attn_pos_bias_u = nullptr;
321
+ struct ggml_tensor * attn_pos_bias_v = nullptr;
322
+ struct ggml_tensor * attn_q_w = nullptr;
323
+ struct ggml_tensor * attn_k_w = nullptr;
324
+ struct ggml_tensor * attn_v_w = nullptr;
325
+ struct ggml_tensor * attn_out_w = nullptr;
326
+ struct ggml_tensor * attn_pos_w = nullptr;
327
+
328
+ struct ggml_tensor * norm_ff2_w = nullptr;
329
+ struct ggml_tensor * norm_ff2_b = nullptr;
330
+
331
+ struct ggml_tensor * ff2_linear1_w = nullptr;
332
+ struct ggml_tensor * ff2_linear2_w = nullptr;
333
+
334
+ struct ggml_tensor * norm_out_w = nullptr;
335
+ struct ggml_tensor * norm_out_b = nullptr;
336
+ };
337
+
338
+ struct parakeet_lsmt_layer {
339
+ struct ggml_tensor * ih_w = nullptr; // input-to-hidden weight
340
+ struct ggml_tensor * hh_w = nullptr; // hidden-to-hidden weight
341
+ struct ggml_tensor * b_h = nullptr; // bias (ih folded into hh at conversion time)
342
+ };
343
+
344
+ struct parakeet_prediction_network {
345
+ struct ggml_tensor * embed_w = nullptr;
346
+
347
+ std::vector<parakeet_lsmt_layer> lstm_layer;
348
+ };
349
+
350
+ struct parakeet_joint_network {
351
+ struct ggml_tensor * pred_w = nullptr;
352
+ struct ggml_tensor * pred_b = nullptr;
353
+ struct ggml_tensor * enc_w = nullptr;
354
+ struct ggml_tensor * enc_b = nullptr;
355
+ struct ggml_tensor * net_w = nullptr;
356
+ struct ggml_tensor * net_b = nullptr;
357
+ };
358
+
359
+ struct parakeet_model {
360
+ parakeet_filters filters;
361
+ parakeet_hparams hparams;
362
+
363
+ struct ggml_tensor * enc_pre_out_w = nullptr;
364
+ struct ggml_tensor * enc_pre_out_b = nullptr;
365
+ struct ggml_tensor * enc_pre_conv_0_w = nullptr;
366
+ struct ggml_tensor * enc_pre_conv_0_b = nullptr;
367
+ struct ggml_tensor * enc_pre_conv_2_w = nullptr;
368
+ struct ggml_tensor * enc_pre_conv_2_b = nullptr;
369
+ struct ggml_tensor * enc_pre_conv_3_w = nullptr;
370
+ struct ggml_tensor * enc_pre_conv_3_b = nullptr;
371
+ struct ggml_tensor * enc_pre_conv_5_w = nullptr;
372
+ struct ggml_tensor * enc_pre_conv_5_b = nullptr;
373
+ struct ggml_tensor * enc_pre_conv_6_w = nullptr;
374
+ struct ggml_tensor * enc_pre_conv_6_b = nullptr;
375
+
376
+ std::vector<parakeet_layer_encoder> layers;
377
+
378
+ parakeet_prediction_network prediction;
379
+
380
+ parakeet_joint_network joint;
381
+
382
+ std::vector<uint32_t> tdt_durations;
383
+
384
+ std::vector<ggml_context *> ctxs;
385
+
386
+ std::vector<ggml_backend_buffer_t> buffers;
387
+
388
+ int n_loaded = 0;
389
+ std::map<std::string, struct ggml_tensor *> tensors;
390
+ };
391
+
392
+ struct parakeet_lstm_state_layer {
393
+ struct ggml_tensor * h_state = nullptr;
394
+ struct ggml_tensor * c_state = nullptr;
395
+ };
396
+
397
+ struct parakeet_lstm_state {
398
+ std::vector<parakeet_lstm_state_layer> layer;
399
+
400
+ std::vector<uint8_t> ctx_buf;
401
+
402
+ ggml_backend_buffer_t buffer = nullptr;
403
+ };
404
+
405
+ struct parakeet_state {
406
+ int64_t t_sample_us = 0;
407
+ int64_t t_encode_us = 0;
408
+ int64_t t_decode_us = 0;
409
+ int64_t t_predict_us = 0;
410
+ int64_t t_predict_build_us = 0; // time spent building the prediction graph
411
+ int64_t t_predict_alloc_us = 0; // time spent in ggml_backend_sched_alloc_graph
412
+ int64_t t_predict_compute_us = 0; // time spent in ggml_graph_compute_helper
413
+ int64_t t_mel_us = 0;
414
+
415
+ int32_t n_sample = 0; // number of tokens sampled
416
+ int32_t n_encode = 0; // number of encoder calls
417
+ int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
418
+ int32_t n_predict = 0; // number of prediction network calls
419
+ int32_t n_fail_p = 0; // number of logprob threshold failures
420
+ int32_t n_fail_h = 0; // number of entropy threshold failures
421
+
422
+ parakeet_mel mel;
423
+
424
+ parakeet_batch batch;
425
+
426
+ int n_frames = 0;
427
+
428
+ std::vector<ggml_backend_t> backends;
429
+
430
+ parakeet_sched sched_encode;
431
+ parakeet_sched sched_decode;
432
+
433
+ // outputs from encoder stages
434
+ struct ggml_tensor * enc_out = nullptr;
435
+ struct ggml_tensor * pred_out = nullptr;
436
+
437
+ std::vector<uint8_t> enc_out_buf;
438
+ ggml_backend_buffer_t enc_out_buffer = nullptr;
439
+
440
+ std::vector<uint8_t> pred_out_buf;
441
+ ggml_backend_buffer_t pred_out_buffer = nullptr;
442
+
443
+ struct ggml_tensor * attn_mask = nullptr;
444
+
445
+ std::vector<float> inp_mel;
446
+ std::vector<float> inp_mask;
447
+
448
+ std::vector<float> logits;
449
+
450
+ std::vector<parakeet_segment> result_all;
451
+
452
+ std::vector<parakeet_token> decoded_tokens;
453
+ std::vector<parakeet_token_data> decoded_token_data;
454
+
455
+ std::string path_model;
456
+
457
+ int32_t n_audio_ctx = 0;
458
+ int32_t sched_encode_n_audio_ctx = 0;
459
+
460
+ parakeet_lstm_state lstm_state;
461
+ };
462
+
463
+ // FFT cache for mel spectrogram computation
464
// Precomputed tables used by the mel spectrogram computation.
struct parakeet_mel_cache {
    int n_fft = 0;

    // In FFT, we frequently use sine and cosine operations with the same values.
    // We can use precalculated values to speed up the process.
    std::vector<float> sin_vals;
    std::vector<float> cos_vals;

    // Hann window (Use cosf to eliminate difference)
    // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
    // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
    std::vector<float> hann_window;

    // Window function from model (Parakeet uses actual window from training)
    // Not touched by init(); it is filled elsewhere.
    std::vector<float> window;

    // Sizes the sin/cos tables and the Hann window for `fft_size` points and
    // fills them. The Hann window is generated in its periodic form.
    void init(int fft_size) {
        n_fft = fft_size;

        sin_vals.resize(n_fft);
        cos_vals.resize(n_fft);
        hann_window.resize(n_fft);

        fill_sin_cos_table();
        fill_hann_window(n_fft, true, hann_window.data());
    }

    // sin_vals[k] / cos_vals[k] hold sin / cos of 2*pi*k/n_fft.
    void fill_sin_cos_table() {
        for (int k = 0; k < n_fft; ++k) {
            const double theta = (2 * M_PI * k) / n_fft;
            sin_vals[k] = sinf(theta);
            cos_vals[k] = cosf(theta);
        }
    }

    // Writes `length` Hann window samples to `output`.
    // The cosine period is `length` when `periodic` is true and `length - 1`
    // otherwise.
    void fill_hann_window(int length, bool periodic, float * output) {
        const int offset = periodic ? 0 : -1;
        for (int k = 0; k < length; ++k) {
            output[k] = 0.5 * (1.0 - cosf((2.0 * M_PI * k) / (length + offset)));
        }
    }
};
508
+
509
+ struct parakeet_context {
510
+ int64_t t_load_us = 0;
511
+ int64_t t_start_us = 0;
512
+
513
+ ggml_type wtype = ggml_type::GGML_TYPE_F16;
514
+ ggml_type itype = ggml_type::GGML_TYPE_F16;
515
+
516
+ parakeet_context_params params;
517
+
518
+ parakeet_model model;
519
+ parakeet_vocab vocab;
520
+
521
+ parakeet_state * state = nullptr;
522
+
523
+ parakeet_mel_cache mel_cache;
524
+
525
+ std::string path_model;
526
+ };
527
+
528
+ struct parakeet_global {
529
+ // We save the log callback globally
530
+ ggml_log_callback log_callback = parakeet_log_callback_default;
531
+ void * log_callback_user_data = nullptr;
532
+ };
533
+
534
+ static parakeet_global g_state;
535
+
536
+ static const std::string PARAKEET_SPM_SPACE = "\xE2\x96\x81";
537
+
538
// Length in bytes of the UTF-8 sequence introduced by lead byte `c`.
// Bytes that are not a valid lead byte (continuation bytes, 0xF8..0xFF)
// yield 1.
static inline int utf8_codepoint_len(unsigned char c) {
    if (c < 0x80)              { return 1; } // 0xxxxxxx
    if (c >= 0xC0 && c <= 0xDF) { return 2; } // 110xxxxx
    if (c >= 0xE0 && c <= 0xEF) { return 3; } // 1110xxxx
    if (c >= 0xF0 && c <= 0xF7) { return 4; } // 11110xxx
    return 1;
}
545
+
546
// True if `piece` is one of the special pieces that carry no text:
// "<unk>", "<s>", "</s>" or "[BLANK]".
static bool is_sentencepiece_control(const std::string & piece) {
    static const char * const control_pieces[] = { "<unk>", "<s>", "</s>", "[BLANK]" };

    for (const char * candidate : control_pieces) {
        if (piece == candidate) {
            return true;
        }
    }
    return false;
}
549
+
550
+ static std::string sentencepiece_normalize(const std::string & text) {
551
+ std::string normalized;
552
+ normalized.reserve(text.size() + PARAKEET_SPM_SPACE.size());
553
+ normalized += PARAKEET_SPM_SPACE; // SentencePiece dummy prefix
554
+
555
+ for (unsigned char c : text) {
556
+ if (std::isspace(c)) {
557
+ normalized += PARAKEET_SPM_SPACE;
558
+ } else {
559
+ normalized += static_cast<char>(c);
560
+ }
561
+ }
562
+
563
+ return normalized;
564
+ }
565
+
566
+ static std::string sentencepiece_piece_to_text(const std::string & piece, bool is_first_piece) {
567
+ if (is_sentencepiece_control(piece)) {
568
+ return "";
569
+ }
570
+
571
+ std::string text;
572
+ text.reserve(piece.size());
573
+
574
+ size_t pos = 0;
575
+ while (pos < piece.size()) {
576
+ if (piece.compare(pos, PARAKEET_SPM_SPACE.size(), PARAKEET_SPM_SPACE) == 0) {
577
+ if (!is_first_piece || !text.empty()) {
578
+ text += ' ';
579
+ }
580
+ pos += PARAKEET_SPM_SPACE.size();
581
+ continue;
582
+ }
583
+
584
+ text += piece[pos];
585
+ ++pos;
586
+ }
587
+
588
+ return text;
589
+ }
590
+
591
+
592
+ static struct parakeet_batch parakeet_batch_init(int32_t n_tokens) {
593
+ parakeet_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, };
594
+
595
+ batch.token = (parakeet_token * ) malloc(sizeof(parakeet_token) * (n_tokens));
596
+ batch.i_time = (int32_t *) malloc(sizeof(int32_t) * (n_tokens));
597
+ batch.pos = (parakeet_pos *) malloc(sizeof(parakeet_pos) * (n_tokens));
598
+ batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens));
599
+ batch.seq_id = (parakeet_seq_id **) malloc(sizeof(parakeet_seq_id *) * (n_tokens + 1));
600
+ for (int i = 0; i < n_tokens; ++i) {
601
+ batch.seq_id[i] = (parakeet_seq_id *) malloc(sizeof(parakeet_seq_id));
602
+ }
603
+ batch.seq_id[n_tokens] = nullptr;
604
+ batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
605
+
606
+ return batch;
607
+ }
608
+
609
+ static void parakeet_batch_free(struct parakeet_batch batch) {
610
+ if (batch.token) free(batch.token);
611
+ if (batch.i_time) free(batch.i_time);
612
+ if (batch.pos) free(batch.pos);
613
+ if (batch.n_seq_id) free(batch.n_seq_id);
614
+ if (batch.seq_id) {
615
+ for (int i = 0; batch.seq_id[i]; ++i) {
616
+ free(batch.seq_id[i]);
617
+ }
618
+ free(batch.seq_id);
619
+ }
620
+ if (batch.logits) free(batch.logits);
621
+ }
622
+
623
+ static void parakeet_batch_prep_legacy(parakeet_batch & batch, const parakeet_token * tokens, int n_tokens, int n_past, int seq_id) {
624
+ batch.n_tokens = n_tokens;
625
+ for (int i = 0; i < n_tokens; ++i) {
626
+ if (tokens) {
627
+ batch.token[i] = tokens[i];
628
+ }
629
+ batch.pos [i] = n_past + i;
630
+ batch.n_seq_id[i] = 1;
631
+ batch.seq_id [i][0] = seq_id;
632
+ batch.logits [i] = 0;
633
+ }
634
+ batch.logits[n_tokens - 1] = 1;
635
+ }
636
+
637
+
638
+ static size_t parakeet_sched_size(struct parakeet_sched & allocr) {
639
+ size_t size = allocr.meta.size();
640
+ for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
641
+ ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i);
642
+ size += ggml_backend_sched_get_buffer_size(allocr.sched, backend);
643
+ }
644
+ return size;
645
+ }
646
+
647
+ static bool parakeet_sched_graph_init(struct parakeet_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) {
648
+ auto & sched = allocr.sched;
649
+ auto & meta = allocr.meta;
650
+
651
+ sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), PARAKEET_MAX_NODES, false, true);
652
+
653
+ if (!sched) {
654
+ PARAKEET_LOG_ERROR("%s: failed to create scheduler\n", __func__);
655
+ return false;
656
+ }
657
+
658
+ meta.resize(ggml_tensor_overhead()*PARAKEET_MAX_NODES + ggml_graph_overhead());
659
+
660
+ if (!ggml_backend_sched_alloc_graph(sched, get_graph())) {
661
+ PARAKEET_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
662
+ ggml_backend_sched_free(sched);
663
+ sched = nullptr;
664
+ return false;
665
+ }
666
+
667
+ ggml_backend_sched_reset(sched);
668
+
669
+ return true;
670
+ }
671
+
672
+ static void parakeet_sched_free(struct parakeet_sched & sched) {
673
+ if (sched.sched) {
674
+ ggml_backend_sched_free(sched.sched);
675
+ sched.sched = nullptr;
676
+ }
677
+
678
+ sched.meta.clear();
679
+ }
680
+
681
+
682
+ template<typename T>
683
+ static void read_safe(parakeet_model_loader * loader, T & dest) {
684
+ loader->read(loader->context, &dest, sizeof(T));
685
+ BYTESWAP_VALUE(dest);
686
+ }
687
+
688
+ static bool parakeet_lstm_state_init(
689
+ struct parakeet_state & pstate,
690
+ ggml_backend_t backend,
691
+ int n_layer,
692
+ int n_pred_dim) {
693
+ parakeet_lstm_state & lstm_state = pstate.lstm_state;
694
+
695
+ lstm_state.ctx_buf.resize(ggml_tensor_overhead() * n_layer * 2);
696
+ lstm_state.layer.resize(n_layer);
697
+
698
+ struct ggml_init_params params = {
699
+ /*.mem_size =*/ lstm_state.ctx_buf.size(),
700
+ /*.mem_buffer =*/ lstm_state.ctx_buf.data(),
701
+ /*.no_alloc =*/ true,
702
+ };
703
+
704
+ struct ggml_context * ctx = ggml_init(params);
705
+
706
+ if (!ctx) {
707
+ PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states context\n", __func__);
708
+ return false;
709
+ }
710
+
711
+
712
+ for (int il = 0; il < n_layer; ++il) {
713
+ lstm_state.layer[il].h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim);
714
+ lstm_state.layer[il].c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim);
715
+ }
716
+
717
+ lstm_state.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
718
+ if (!lstm_state.buffer) {
719
+ PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states\n", __func__);
720
+ return false;
721
+ }
722
+
723
+ ggml_backend_buffer_clear(lstm_state.buffer, 0);
724
+
725
+ ggml_free(ctx);
726
+
727
+ return true;
728
+ }
729
+
730
+ static bool parakeet_pred_state_init(
731
+ struct parakeet_state & pstate,
732
+ ggml_backend_t backend,
733
+ int n_pred_dim) {
734
+ pstate.pred_out_buf.resize(ggml_tensor_overhead());
735
+
736
+ struct ggml_init_params params = {
737
+ /*.mem_size =*/ pstate.pred_out_buf.size(),
738
+ /*.mem_buffer =*/ pstate.pred_out_buf.data(),
739
+ /*.no_alloc =*/ true,
740
+ };
741
+
742
+ struct ggml_context * ctx = ggml_init(params);
743
+ if (!ctx) {
744
+ PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor context\n", __func__);
745
+ return false;
746
+ }
747
+
748
+ pstate.pred_out = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim);
749
+ pstate.pred_out_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
750
+ if (!pstate.pred_out_buffer) {
751
+ PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor\n", __func__);
752
+ ggml_free(ctx);
753
+ return false;
754
+ }
755
+
756
+ ggml_free(ctx);
757
+
758
+ return true;
759
+ }
760
+
761
+ static bool parakeet_enc_state_init(
762
+ struct parakeet_state & pstate,
763
+ ggml_backend_t backend,
764
+ int n_audio_state,
765
+ int n_frames_max) {
766
+ pstate.enc_out_buf.resize(ggml_tensor_overhead());
767
+
768
+ struct ggml_init_params params = {
769
+ /*.mem_size =*/ pstate.enc_out_buf.size(),
770
+ /*.mem_buffer =*/ pstate.enc_out_buf.data(),
771
+ /*.no_alloc =*/ true,
772
+ };
773
+
774
+ struct ggml_context * ctx = ggml_init(params);
775
+ if (!ctx) {
776
+ PARAKEET_LOG_ERROR("%s: failed to allocate memory for enc_out tensor context\n", __func__);
777
+ return false;
778
+ }
779
+
780
+ pstate.enc_out = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_frames_max);
781
+ pstate.enc_out_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
782
+ if (!pstate.enc_out_buffer) {
783
+ PARAKEET_LOG_ERROR("%s: failed to allocate memory for enc_out tensor\n", __func__);
784
+ ggml_free(ctx);
785
+ return false;
786
+ }
787
+
788
+ ggml_free(ctx);
789
+
790
+ return true;
791
+ }
792
+
793
+ static ggml_backend_t parakeet_backend_init_gpu(const parakeet_context_params & params) {
794
+ ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
795
+
796
+ ggml_backend_dev_t dev = nullptr;
797
+
798
+ int cnt = 0;
799
+ if (params.use_gpu) {
800
+ for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
801
+ ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
802
+ enum ggml_backend_dev_type dev_type = ggml_backend_dev_type(dev_cur);
803
+ const char * dev_name = ggml_backend_dev_name(dev_cur);
804
+ PARAKEET_LOG_INFO("%s: device %zu: %s (type: %d)\n", __func__, i, dev_name, dev_type);
805
+ if (dev_type == GGML_BACKEND_DEVICE_TYPE_GPU || dev_type == GGML_BACKEND_DEVICE_TYPE_IGPU) {
806
+ PARAKEET_LOG_INFO("%s: found GPU device %zu: %s (type: %d, cnt: %d)\n", __func__, i, dev_name, dev_type, cnt);
807
+ if (cnt == params.gpu_device) {
808
+ dev = dev_cur;
809
+ }
810
+
811
+ if (++cnt > params.gpu_device) {
812
+ break;
813
+ }
814
+ }
815
+ }
816
+ }
817
+
818
+ if (dev == nullptr) {
819
+ PARAKEET_LOG_INFO("%s: no GPU found\n", __func__);
820
+ return nullptr;
821
+ }
822
+
823
+ PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
824
+ ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
825
+ if (!result) {
826
+ PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
827
+ }
828
+
829
+ return result;
830
+ }
831
+
832
+ static std::vector<ggml_backend_t> parakeet_backend_init(const parakeet_context_params & params) {
833
+ std::vector<ggml_backend_t> result;
834
+
835
+ ggml_backend_t backend_gpu = parakeet_backend_init_gpu(params);
836
+
837
+ if (backend_gpu) {
838
+ result.push_back(backend_gpu);
839
+ }
840
+
841
+ // ACCEL backends
842
+ for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
843
+ ggml_backend_dev_t dev = ggml_backend_dev_get(i);
844
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
845
+ PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
846
+ ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
847
+ if (!backend) {
848
+ PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
849
+ continue;
850
+ }
851
+ result.push_back(backend);
852
+ }
853
+ }
854
+
855
+ ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
856
+ if (backend_cpu == nullptr) {
857
+ throw std::runtime_error("failed to initialize CPU backend");
858
+ }
859
+ result.push_back(backend_cpu);
860
+
861
+ return result;
862
+ }
863
+
864
+ using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
865
+
866
+ static buft_list_t make_buft_list(parakeet_context_params & params) {
867
+ // Prio order: GPU -> CPU Extra -> CPU
868
+ buft_list_t buft_list;
869
+
870
+ // GPU
871
+ if (params.use_gpu) {
872
+ int cnt = 0;
873
+ for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
874
+ ggml_backend_dev_t dev = ggml_backend_dev_get(i);
875
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU) {
876
+ if (cnt == params.gpu_device) {
877
+ auto * buft = ggml_backend_dev_buffer_type(dev);
878
+ if (buft) {
879
+ buft_list.emplace_back(dev, buft);
880
+ }
881
+ }
882
+
883
+ if (++cnt > params.gpu_device) {
884
+ break;
885
+ }
886
+ }
887
+ }
888
+ }
889
+
890
+ // CPU Extra
891
+ auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
892
+ auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
893
+ auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
894
+ ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
895
+ if (get_extra_bufts_fn) {
896
+ ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev);
897
+ while (extra_bufts && *extra_bufts) {
898
+ buft_list.emplace_back(cpu_dev, *extra_bufts);
899
+ ++extra_bufts;
900
+ }
901
+ }
902
+
903
+ // CPU
904
+ buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type());
905
+
906
+ return buft_list;
907
+ }
908
+
909
+ static bool weight_buft_supported(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
910
+ bool op_supported = true;
911
+
912
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
913
+ ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU ||
914
+ (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
915
+ // GPU and default CPU backend support all operators
916
+ op_supported = true;
917
+ } else {
918
+ switch (op) {
919
+ // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS
920
+ case GGML_OP_GET_ROWS:
921
+ case GGML_OP_MUL_MAT: {
922
+ ggml_init_params params = {
923
+ /*.mem_size =*/ 2 * ggml_tensor_overhead(),
924
+ /*.mem_buffer =*/ nullptr,
925
+ /*.no_alloc =*/ true,
926
+ };
927
+
928
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
929
+ if (!ctx_ptr) {
930
+ throw std::runtime_error("failed to create ggml context");
931
+ }
932
+ ggml_context * ctx = ctx_ptr.get();
933
+
934
+ ggml_tensor * op_tensor = nullptr;
935
+
936
+ if (op == GGML_OP_MUL_MAT) {
937
+ int64_t n_ctx = hparams.n_audio_ctx;
938
+ ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
939
+ op_tensor = ggml_mul_mat(ctx, w, b);
940
+ } else if (op == GGML_OP_GET_ROWS) {
941
+ int64_t num_indices = 8;
942
+ ggml_tensor * indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices);
943
+ op_tensor = ggml_get_rows(ctx, w, indices);
944
+ }
945
+
946
+ // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
947
+ GGML_ASSERT(w->buffer == nullptr);
948
+ w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
949
+ op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
950
+ ggml_backend_buffer_free(w->buffer);
951
+ w->buffer = nullptr;
952
+ break;
953
+ }
954
+ default: {
955
+ op_supported = false;
956
+ break;
957
+ }
958
+ };
959
+ }
960
+
961
+ return op_supported;
962
+ }
963
+
964
+ static ggml_backend_buffer_type_t select_weight_buft(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
965
+ GGML_ASSERT(!buft_list.empty());
966
+ for (const auto & p : buft_list) {
967
+ ggml_backend_dev_t dev = p.first;
968
+ ggml_backend_buffer_type_t buft = p.second;
969
+ if (weight_buft_supported(hparams, w, op, buft, dev)) {
970
+ return buft;
971
+ }
972
+ }
973
+
974
+ return nullptr;
975
+ }
976
+
977
+
978
+ // load the model from a ggml file
979
+ //
980
+
981
+ // see the convert-parakeet-to-ggml.py script for details
982
+ //
983
+ static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_context & wctx) {
984
+ PARAKEET_LOG_INFO("%s: loading model\n", __func__);
985
+
986
+ const int64_t t_start_us = ggml_time_us();
987
+
988
+ wctx.t_start_us = t_start_us;
989
+
990
+ auto & model = wctx.model;
991
+ auto & vocab = wctx.vocab;
992
+
993
+ // verify magic
994
+ {
995
+ uint32_t magic;
996
+ read_safe(loader, magic);
997
+ if (magic != GGML_FILE_MAGIC) {
998
+ PARAKEET_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
999
+ return false;
1000
+ }
1001
+ }
1002
+
1003
+ //load hparams
1004
+ parakeet_hparams hparams;
1005
+ {
1006
+ read_safe(loader, hparams.n_vocab);
1007
+ read_safe(loader, hparams.n_audio_ctx);
1008
+ read_safe(loader, hparams.n_audio_state);
1009
+ read_safe(loader, hparams.n_audio_head);
1010
+ read_safe(loader, hparams.n_audio_layer);
1011
+ read_safe(loader, hparams.n_mels);
1012
+ read_safe(loader, hparams.ftype);
1013
+ read_safe(loader, hparams.n_fft);
1014
+ read_safe(loader, hparams.subsampling_factor);
1015
+ read_safe(loader, hparams.n_subsampling_channels);
1016
+ read_safe(loader, hparams.n_conv_kernel);
1017
+ read_safe(loader, hparams.n_pred_dim);
1018
+ read_safe(loader, hparams.n_pred_layers);
1019
+ read_safe(loader, hparams.n_tdt_durations);
1020
+ read_safe(loader, hparams.n_max_tokens);
1021
+
1022
+ hparams.arch = PARAKEET_ARCH_TDT;
1023
+ wctx.model.hparams = hparams;
1024
+
1025
+ const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
1026
+
1027
+ hparams.ftype %= GGML_QNT_VERSION_FACTOR;
1028
+
1029
+ // for the big tensors, we have the option to store the data in 16-bit floats or quantized
1030
+ // in order to save memory and also to speed up the computation
1031
+ wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) hparams.ftype);
1032
+ if (wctx.wtype == GGML_TYPE_COUNT) {
1033
+ PARAKEET_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, hparams.ftype);
1034
+ return false;
1035
+ }
1036
+
1037
+ const char* arch_name = hparams.arch == PARAKEET_ARCH_TDT ? "Parakeet TDT" : "unknown";
1038
+ PARAKEET_LOG_INFO("%s: arch = %s\n", __func__, arch_name);
1039
+ PARAKEET_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
1040
+ PARAKEET_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
1041
+ PARAKEET_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
1042
+ PARAKEET_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
1043
+ PARAKEET_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
1044
+ PARAKEET_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels);
1045
+ PARAKEET_LOG_INFO("%s: n_fft = %d\n", __func__, hparams.n_fft);
1046
+ PARAKEET_LOG_INFO("%s: eps = %f\n", __func__, hparams.eps);
1047
+ PARAKEET_LOG_INFO("%s: ftype = %d\n", __func__, hparams.ftype);
1048
+ PARAKEET_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr);
1049
+ PARAKEET_LOG_INFO("%s: subsampling_factor = %d\n", __func__, hparams.subsampling_factor);
1050
+ PARAKEET_LOG_INFO("%s: n_subsampling_channels = %d\n", __func__, hparams.n_subsampling_channels);
1051
+ PARAKEET_LOG_INFO("%s: n_conv_kernel = %d\n", __func__, hparams.n_conv_kernel);
1052
+ PARAKEET_LOG_INFO("%s: n_pred_dim = %d\n", __func__, hparams.n_pred_dim);
1053
+ PARAKEET_LOG_INFO("%s: n_pred_layers = %d\n", __func__, hparams.n_pred_layers);
1054
+ PARAKEET_LOG_INFO("%s: n_tdt_durations = %d\n", __func__, hparams.n_tdt_durations);
1055
+ PARAKEET_LOG_INFO("%s: n_max_tokens = %d\n", __func__, hparams.n_max_tokens);
1056
+ }
1057
+
1058
+ // load mel filters
1059
+ {
1060
+ auto & filters = wctx.model.filters;
1061
+
1062
+ read_safe(loader, filters.n_mel);
1063
+ read_safe(loader, filters.n_fb);
1064
+
1065
+ filters.data.resize(filters.n_mel * filters.n_fb);
1066
+ loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
1067
+ BYTESWAP_FILTERS(filters);
1068
+ }
1069
+
1070
+ // load window function
1071
+ {
1072
+ int32_t n_window = 0;
1073
+ read_safe(loader, n_window);
1074
+
1075
+ wctx.mel_cache.window.resize(n_window);
1076
+ loader->read(loader->context, wctx.mel_cache.window.data(), n_window * sizeof(float));
1077
+
1078
+ #ifdef GGML_BIG_ENDIAN
1079
+ for (auto & datum : wctx.mel_cache.window) {
1080
+ datum = byteswap(datum);
1081
+ }
1082
+ #endif
1083
+
1084
+ PARAKEET_LOG_INFO("%s: loaded window function with %d samples\n", __func__, n_window);
1085
+ }
1086
+
1087
+ // load TDT (Token and Duration Transducer) values
1088
+ {
1089
+ auto & tdt_durations = wctx.model.tdt_durations;
1090
+ tdt_durations.resize(hparams.n_tdt_durations);
1091
+ loader->read(loader->context, tdt_durations.data(), hparams.n_tdt_durations * sizeof(uint32_t));
1092
+
1093
+ PARAKEET_LOG_INFO("%s: loaded tdt_durations: [", __func__);
1094
+ for (const auto value : tdt_durations) {
1095
+ PARAKEET_LOG_INFO("%u ", value);
1096
+ }
1097
+ PARAKEET_LOG_INFO("]\n");
1098
+ }
1099
+
1100
+ // load vocab
1101
+ {
1102
+ int32_t n_vocab = 0;
1103
+ read_safe(loader, n_vocab);
1104
+
1105
+ std::string word;
1106
+ std::vector<char> tmp;
1107
+
1108
+ tmp.reserve(128);
1109
+
1110
+ for (int i = 0; i < n_vocab; i++) {
1111
+ uint32_t len;
1112
+ read_safe(loader, len);
1113
+
1114
+ if (len > 0) {
1115
+ tmp.resize(len);
1116
+ loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
1117
+ word.assign(&tmp[0], tmp.size());
1118
+ } else {
1119
+ PARAKEET_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
1120
+ word = "";
1121
+ }
1122
+
1123
+ vocab.token_to_id[word] = i;
1124
+ vocab.id_to_token[i] = word;
1125
+ vocab.max_token_length = std::max(vocab.max_token_length, word.size());
1126
+ }
1127
+ // Blank token for transducer is at index n_vocab (8192), outside the vocabulary
1128
+ int blank_id = n_vocab;
1129
+ vocab.token_blank = blank_id;
1130
+ vocab.id_to_token[blank_id] = "[BLANK]";
1131
+ vocab.token_to_id["[BLANK]"] = blank_id;
1132
+
1133
+ // Set special token IDs by looking them up in the loaded vocabulary
1134
+ // These are from the SentencePiece vocab file loaded above
1135
+ if (vocab.token_to_id.find("<unk>") != vocab.token_to_id.end()) {
1136
+ vocab.token_unk = vocab.token_to_id.at("<unk>");
1137
+ } else {
1138
+ vocab.token_unk = 0; // Fallback
1139
+ }
1140
+
1141
+ if (vocab.token_to_id.find("<s>") != vocab.token_to_id.end()) {
1142
+ vocab.token_bos = vocab.token_to_id.at("<s>");
1143
+ } else if (vocab.token_to_id.find("<|startoftranscript|>") != vocab.token_to_id.end()) {
1144
+ vocab.token_bos = vocab.token_to_id.at("<|startoftranscript|>");
1145
+ } else {
1146
+ vocab.token_bos = 0; // Fallback
1147
+ }
1148
+
1149
+ if (vocab.token_to_id.find("</s>") != vocab.token_to_id.end()) {
1150
+ vocab.token_eos = vocab.token_to_id.at("</s>");
1151
+ } else if (vocab.token_to_id.find("<|endoftext|>") != vocab.token_to_id.end()) {
1152
+ vocab.token_eos = vocab.token_to_id.at("<|endoftext|>");
1153
+ } else {
1154
+ vocab.token_eos = 0; // Fallback
1155
+ }
1156
+
1157
+ vocab.n_vocab = model.hparams.n_vocab;
1158
+
1159
+ PARAKEET_LOG_INFO("%s: loaded vocab with %d tokens (blank_id=%d, unk=%d, bos=%d, eos=%d)\n",
1160
+ __func__, n_vocab, blank_id, vocab.token_unk, vocab.token_bos, vocab.token_eos);
1161
+ }
1162
+
1163
+ const ggml_type wtype = wctx.wtype;
1164
+
1165
+
1166
+ const int n_audio_layer = hparams.n_audio_layer;
1167
+
1168
+ // Calculate tensor count: pre_encode (12) + encoder layers (29 per layer) + prediction (9) + joint (6)
1169
+ size_t n_tensors = 12 + (29 * n_audio_layer) + 9 + 6;
1170
+
1171
+ std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
1172
+ auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
1173
+ auto it = ctx_map.find(buft);
1174
+ if (it == ctx_map.end()) {
1175
+ ggml_init_params params = {
1176
+ /*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
1177
+ /*.mem_buffer =*/ nullptr,
1178
+ /*.no_alloc =*/ true,
1179
+ };
1180
+
1181
+ ggml_context * ctx = ggml_init(params);
1182
+ if (!ctx) {
1183
+ throw std::runtime_error("failed to create ggml context");
1184
+ }
1185
+
1186
+ ctx_map[buft] = ctx;
1187
+ wctx.model.ctxs.emplace_back(ctx);
1188
+
1189
+ return ctx;
1190
+ }
1191
+
1192
+ return it->second;
1193
+ };
1194
+
1195
+ // Create a list of available bufts, in priority order
1196
+ buft_list_t buft_list = make_buft_list(wctx.params);
1197
+
1198
+ auto create_tensor = [&](parakeet_tensor type, ggml_tensor * meta, int layer = -1) -> ggml_tensor * {
1199
+ ggml_op op = PARAKEET_TENSOR_INFO.at(type);
1200
+ ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
1201
+ if (!buft) {
1202
+ throw std::runtime_error(format("failed to find a compatible buffer type for parakeet tensor %s",
1203
+ PARAKEET_TENSOR_NAMES.at(type)));
1204
+ }
1205
+
1206
+ ggml_context * ctx = get_ctx(buft);
1207
+ ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
1208
+
1209
+ std::string tensor_name;
1210
+ if (layer >= 0) {
1211
+ tensor_name = format(PARAKEET_TENSOR_NAMES.at(type), layer);
1212
+ } else {
1213
+ tensor_name = PARAKEET_TENSOR_NAMES.at(type);
1214
+ }
1215
+
1216
+ wctx.model.tensors[tensor_name] = tensor;
1217
+
1218
+ return tensor;
1219
+ };
1220
+
1221
+ // prepare tensors for the weights
1222
+
1223
+ ggml_init_params params = {
1224
+ /*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
1225
+ /*.mem_buffer =*/ nullptr,
1226
+ /*.no_alloc =*/ true,
1227
+ };
1228
+
1229
+ ggml_context * ctx = ggml_init(params);
1230
+
1231
+ const int n_audio_state = hparams.n_audio_state;
1232
+
1233
+ model.layers.resize(n_audio_layer);
1234
+
1235
+ // Encoder pre_encode
1236
+ const int n_subsampling_channels = hparams.n_subsampling_channels;
1237
+ const int n_pre_enc_features = (hparams.n_mels / hparams.subsampling_factor) * n_subsampling_channels;
1238
+ model.enc_pre_out_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_pre_enc_features, n_audio_state));
1239
+ ggml_set_name(model.enc_pre_out_w, "enc_pre_out_w");
1240
+ model.enc_pre_out_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
1241
+ ggml_set_name(model.enc_pre_out_b, "enc_pre_out_b");
1242
+
1243
+ model.enc_pre_conv_0_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels));
1244
+ ggml_set_name(model.enc_pre_conv_0_w, "enc_pre_conv_0_w");
1245
+ model.enc_pre_conv_0_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
1246
+ ggml_set_name(model.enc_pre_conv_0_b, "enc_pre_conv_0_b");
1247
+
1248
+ model.enc_pre_conv_2_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels));
1249
+ ggml_set_name(model.enc_pre_conv_2_w, "enc_pre_conv_2_w");
1250
+ model.enc_pre_conv_2_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
1251
+ ggml_set_name(model.enc_pre_conv_2_b, "enc_pre_conv_2_b");
1252
+
1253
+ model.enc_pre_conv_3_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, n_subsampling_channels));
1254
+ ggml_set_name(model.enc_pre_conv_3_w, "enc_pre_conv_3_w");
1255
+ model.enc_pre_conv_3_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
1256
+ ggml_set_name(model.enc_pre_conv_3_b, "enc_pre_conv_3_b");
1257
+
1258
+ model.enc_pre_conv_5_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels));
1259
+ ggml_set_name(model.enc_pre_conv_5_w, "enc_pre_conv_5_w");
1260
+ model.enc_pre_conv_5_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
1261
+ ggml_set_name(model.enc_pre_conv_5_b, "enc_pre_conv_5_b");
1262
+
1263
+ model.enc_pre_conv_6_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, n_subsampling_channels));
1264
+ ggml_set_name(model.enc_pre_conv_6_w, "enc_pre_conv_6_w");
1265
+ model.enc_pre_conv_6_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
1266
+ ggml_set_name(model.enc_pre_conv_6_b, "enc_pre_conv_6_b");
1267
+
1268
+ // Encoder layers
1269
+ for (int i = 0; i < n_audio_layer; ++i) {
1270
+ auto & layer = model.layers[i];
1271
+
1272
+ // Feed forward 1
1273
+ layer.norm_ff1_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1274
+ layer.norm_ff1_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1275
+ layer.ff1_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
1276
+ ggml_format_name(layer.ff1_linear1_w, "enc_%d_ff1_linear1_w", i);
1277
+ layer.ff1_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
1278
+ ggml_format_name(layer.ff1_linear2_w, "enc_%d_ff1_linear2_w", i);
1279
+
1280
+ // Convolution module
1281
+ layer.norm_conv_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1282
+ ggml_format_name(layer.norm_conv_w, "enc_%d_norm_conv_w", i);
1283
+ layer.norm_conv_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1284
+ ggml_format_name(layer.norm_conv_b, "enc_%d_norm_conv_b", i);
1285
+ layer.conv_pw1_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 2*n_audio_state), i);
1286
+ ggml_format_name(layer.conv_pw1_w, "enc_%d_conv_pw1_w", i);
1287
+ layer.conv_dw_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_conv_kernel, n_audio_state), i);
1288
+ ggml_format_name(layer.conv_dw_w, "enc_%d_conv_dw_w", i);
1289
+ layer.conv_bn_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1290
+ ggml_format_name(layer.conv_bn_w, "enc_%d_conv_bn_w", i);
1291
+ layer.conv_bn_b = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1292
+ ggml_format_name(layer.conv_bn_b, "enc_%d_conv_bn_b", i);
1293
+ layer.conv_bn_mean = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_MEAN, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1294
+ layer.conv_bn_var = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_VAR, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1295
+ ggml_format_name(layer.conv_bn_var, "enc_%d_conv_bn_var", i);
1296
+ layer.conv_bn_num_batches = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), i);
1297
+ layer.conv_pw2_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
1298
+ ggml_format_name(layer.conv_pw2_w, "enc_%d_conv_pw2_w", i);
1299
+
1300
+ // Self attention
1301
+ layer.norm_attn_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1302
+ layer.norm_attn_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1303
+ layer.attn_pos_bias_u = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_audio_state / hparams.n_audio_head, hparams.n_audio_head), i);
1304
+ layer.attn_pos_bias_v = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_audio_state / hparams.n_audio_head, hparams.n_audio_head), i);
1305
+ layer.attn_q_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
1306
+ layer.attn_k_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
1307
+ layer.attn_v_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
1308
+ layer.attn_out_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
1309
+ layer.attn_pos_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
1310
+ ggml_format_name(layer.attn_pos_w, "enc_%d_attn_pos_w", i);
1311
+
1312
+ // Feed forward 2
1313
+ layer.norm_ff2_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1314
+ layer.norm_ff2_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1315
+ layer.ff2_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
1316
+ layer.ff2_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
1317
+
1318
+ // Output norm
1319
+ layer.norm_out_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1320
+ layer.norm_out_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1321
+ }
1322
+
1323
+ // Prediction network (decoder)
1324
+ const int dec_hidden = hparams.n_pred_dim;
1325
+ const int n_pred_embed = hparams.n_vocab + 1; // vocab + blank token
1326
+ const int n_lstm_gates = 4 * dec_hidden; // 4 LSTM gates
1327
+ const int n_joint_out = hparams.n_vocab + hparams.n_tdt_durations + 1; // vocab + durations + blank
1328
+
1329
+ // The prediction/joint hidden dimension is 640, which is not a multiple of the
1330
+ // K-quant block size (256). For K-quant models, we keep these tensors at F32.
1331
+ const int blck = ggml_blck_size(wtype);
1332
+ const ggml_type pred_wtype = (blck > 1 && dec_hidden % blck != 0) ? GGML_TYPE_F32 : wtype;
1333
+ const ggml_type join_wtype = pred_wtype;
1334
+
1335
+ model.prediction.embed_w = create_tensor(PARAKEET_TENSOR_PRED_EMBED_WEIGHT, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_pred_embed));
1336
+ model.prediction.lstm_layer.resize(hparams.n_pred_layers);
1337
+ for (int i = 0; i < hparams.n_pred_layers; ++i) {
1338
+ auto & layer = model.prediction.lstm_layer[i];
1339
+ layer.ih_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_lstm_gates), i);
1340
+ ggml_format_name(layer.ih_w, "pred_%d_ih_w", i);
1341
+
1342
+ layer.hh_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_lstm_gates), i);
1343
+ ggml_format_name(layer.hh_w, "pred_%d_hh_w", i);
1344
+
1345
+ layer.b_h = create_tensor(PARAKEET_TENSOR_PRED_LSTM_BIAS_H, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_lstm_gates), i);
1346
+ ggml_format_name(layer.b_h, "pred_%d_b_h", i);
1347
+ }
1348
+
1349
+ // Joint network
1350
+ model.joint.pred_w = create_tensor(PARAKEET_TENSOR_JOINT_PRED_WEIGHT, ggml_new_tensor_2d(ctx, join_wtype, dec_hidden, dec_hidden));
1351
+ ggml_set_name(model.joint.pred_w, "pred_w");
1352
+ model.joint.pred_b = create_tensor(PARAKEET_TENSOR_JOINT_PRED_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden));
1353
+ ggml_set_name(model.joint.pred_b, "pred_b");
1354
+ model.joint.enc_w = create_tensor(PARAKEET_TENSOR_JOINT_ENC_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, dec_hidden));
1355
+ ggml_set_name(model.joint.enc_w, "enc_w");
1356
+ model.joint.enc_b = create_tensor(PARAKEET_TENSOR_JOINT_ENC_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden));
1357
+ ggml_set_name(model.joint.enc_b, "enc_b");
1358
+ model.joint.net_w = create_tensor(PARAKEET_TENSOR_JOINT_NET_WEIGHT, ggml_new_tensor_2d(ctx, join_wtype, dec_hidden, n_joint_out));
1359
+ ggml_set_name(model.joint.net_w, "net_w");
1360
+ model.joint.net_b = create_tensor(PARAKEET_TENSOR_JOINT_NET_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_joint_out));
1361
+ ggml_set_name(model.joint.net_b, "net_b");
1362
+
1363
+ ggml_free(ctx);
1364
+
1365
+ // allocate tensors in the backend buffers
1366
+ for (auto & p : ctx_map) {
1367
+ ggml_backend_buffer_type_t buft = p.first;
1368
+ ggml_context * ctx = p.second;
1369
+ ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
1370
+ if (buf) {
1371
+ wctx.model.buffers.emplace_back(buf);
1372
+
1373
+ size_t size_main = ggml_backend_buffer_get_size(buf);
1374
+ PARAKEET_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
1375
+ }
1376
+ }
1377
+
1378
+ // load weights
1379
+ {
1380
+ size_t total_size = 0;
1381
+
1382
+ auto & tensors_map = wctx.model.tensors;
1383
+ int & n_loaded = wctx.model.n_loaded;
1384
+
1385
+ n_loaded = 0;
1386
+
1387
+ std::vector<char> read_buf;
1388
+
1389
+ while (true) {
1390
+ int32_t n_dims;
1391
+ int32_t length;
1392
+ int32_t ttype;
1393
+
1394
+ read_safe(loader, n_dims);
1395
+ read_safe(loader, length);
1396
+ read_safe(loader, ttype);
1397
+
1398
+ if (loader->eof(loader->context)) {
1399
+ break;
1400
+ }
1401
+
1402
+ int32_t nelements = 1;
1403
+ int32_t ne[4] = { 1, 1, 1, 1 };
1404
+ for (int i = 0; i < n_dims; ++i) {
1405
+ read_safe(loader, ne[i]);
1406
+ nelements *= ne[i];
1407
+ }
1408
+
1409
+ std::string name;
1410
+ std::vector<char> tmp(length); // create a buffer
1411
+ loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
1412
+ name.assign(&tmp[0], tmp.size());
1413
+
1414
+ if (tensors_map.find(name) == tensors_map.end()) {
1415
+ PARAKEET_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
1416
+ return false;
1417
+ }
1418
+
1419
+ auto tensor = tensors_map[name.data()];
1420
+
1421
+ if (ggml_nelements(tensor) != nelements) {
1422
+ PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
1423
+ PARAKEET_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
1424
+ __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
1425
+ return false;
1426
+ }
1427
+
1428
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2] || tensor->ne[3] != ne[3]) {
1429
+ PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d, %d], expected [%d, %d, %d, %d]\n",
1430
+ __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], (int) tensor->ne[3], ne[0], ne[1], ne[2], ne[3]);
1431
+ return false;
1432
+ }
1433
+
1434
+ const size_t bpe = ggml_type_size(ggml_type(ttype));
1435
+
1436
+ if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
1437
+ PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
1438
+ __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
1439
+ return false;
1440
+ }
1441
+
1442
+ if (ggml_backend_buffer_is_host(tensor->buffer)) {
1443
+ // for the CPU and Metal backend, we can read directly into the tensor
1444
+ loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
1445
+ BYTESWAP_TENSOR(tensor);
1446
+ } else {
1447
+ // read into a temporary buffer first, then copy to device memory
1448
+ read_buf.resize(ggml_nbytes(tensor));
1449
+
1450
+ loader->read(loader->context, read_buf.data(), read_buf.size());
1451
+
1452
+ ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
1453
+ }
1454
+
1455
+ total_size += ggml_nbytes(tensor);
1456
+ n_loaded++;
1457
+ }
1458
+
1459
+ PARAKEET_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
1460
+
1461
+ if (n_loaded == 0) {
1462
+ PARAKEET_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
1463
+ } else if (n_loaded != (int) tensors_map.size()) {
1464
+ PARAKEET_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, tensors_map.size(), n_loaded);
1465
+ return false;
1466
+ }
1467
+ }
1468
+
1469
+ auto & buffers = wctx.model.buffers;
1470
+ for (auto & buf : buffers) {
1471
+ ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
1472
+ }
1473
+
1474
+ wctx.t_load_us = ggml_time_us() - t_start_us;
1475
+
1476
+ return true;
1477
+ }
1478
+
1479
+ // conv subsampling + conformer encoder
1480
+ static struct ggml_cgraph * parakeet_build_graph_encode(parakeet_context & pctx, parakeet_state & pstate) {
1481
+ const auto & model = pctx.model;
1482
+ const auto & hparams = model.hparams;
1483
+ const int n_mel_time = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : hparams.n_audio_ctx;
1484
+ const int n_mels = hparams.n_mels;
1485
+ const int n_layer = hparams.n_audio_layer;
1486
+ const int n_state = hparams.n_audio_state;
1487
+ const float fc_factor = 0.5f;
1488
+
1489
+ struct ggml_init_params params = {
1490
+ /*.mem_size =*/ pstate.sched_encode.meta.size(),
1491
+ /*.mem_buffer =*/ pstate.sched_encode.meta.data(),
1492
+ /*.no_alloc =*/ true,
1493
+ };
1494
+
1495
+ struct ggml_context * ctx0 = ggml_init(params);
1496
+ ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false);
1497
+
1498
+ // Conv subsampling
1499
+
1500
+ // [freq, time]
1501
+ struct ggml_tensor * mel = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_mels, n_mel_time, 1, 1);
1502
+ ggml_set_name(mel, "mel");
1503
+ ggml_set_input(mel);
1504
+
1505
+ // [freq, time, channels, batch]
1506
+ struct ggml_tensor * cur = ggml_conv_2d(ctx0, model.enc_pre_conv_0_w, mel, 2, 2, 1, 1, 1, 1);
1507
+ cur = ggml_add(ctx0, cur, model.enc_pre_conv_0_b);
1508
+ ggml_set_name(cur, "pre_conv_0");
1509
+
1510
+ cur = ggml_relu(ctx0, cur);
1511
+ ggml_set_name(cur, "pre_conv_0_relu");
1512
+
1513
+ // [freq, time, channels, batch]
1514
+ cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_2_w, cur, 2, 2, 1, 1, 1, 1);
1515
+ cur = ggml_add(ctx0, cur, model.enc_pre_conv_2_b);
1516
+ ggml_set_name(cur, "pre_conv_2");
1517
+
1518
+ // [freq, time, channels, batch]
1519
+ cur = ggml_conv_2d(ctx0, model.enc_pre_conv_3_w, cur, 1, 1, 0, 0, 1, 1);
1520
+ cur = ggml_add(ctx0, cur, model.enc_pre_conv_3_b);
1521
+ ggml_set_name(cur, "pre_conv_3");
1522
+
1523
+ cur = ggml_relu(ctx0, cur);
1524
+ ggml_set_name(cur, "pre_conv_3_relu");
1525
+
1526
+ // [freq, time, channels, batch]
1527
+ cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_5_w, cur, 2, 2, 1, 1, 1, 1);
1528
+ ggml_set_name(cur, "pre_conv_5_direct");
1529
+ cur = ggml_add(ctx0, cur, model.enc_pre_conv_5_b);
1530
+ ggml_set_name(cur, "pre_conv_5");
1531
+
1532
+ // [freq, time, channels, batch]
1533
+ cur = ggml_conv_2d(ctx0, model.enc_pre_conv_6_w, cur, 1, 1, 0, 0, 1, 1);
1534
+ cur = ggml_add(ctx0, cur, model.enc_pre_conv_6_b);
1535
+ ggml_set_name(cur, "pre_conv_6");
1536
+
1537
+ cur = ggml_relu(ctx0, cur);
1538
+ ggml_set_name(cur, "pre_conv_6_relu");
1539
+
1540
+ // [freq, time, chan]
1541
+ cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1542
+ // [freq, chan, time]
1543
+ cur = ggml_cont(ctx0, cur);
1544
+
1545
+ const int n_freq = cur->ne[0]; // 16
1546
+ const int n_chan = cur->ne[1]; // 256
1547
+ const int n_frames = cur->ne[2]; // time
1548
+
1549
+ // [freq, time, chan, batch] -> [(freq * chan), time]
1550
+ cur = ggml_reshape_2d(ctx0, cur, n_freq * n_chan, n_frames);
1551
+
1552
+ cur = ggml_mul_mat(ctx0, model.enc_pre_out_w, cur);
1553
+ cur = ggml_add(ctx0, cur, model.enc_pre_out_b);
1554
+
1555
+ ggml_set_name(cur, "pre_enc_out");
1556
+
1557
+ // Encoder
1558
+ // cur: [n_state, n_enc_time]
1559
+
1560
+ const int n_time = cur->ne[1];
1561
+ const bool local_attn = n_time > PARAKEET_LOCAL_ATTN_THRESHOLD;
1562
+ const int att_left = local_attn ? PARAKEET_LOCAL_ATTN_WINDOW : n_time - 1;
1563
+ const int att_right = local_attn ? PARAKEET_LOCAL_ATTN_WINDOW : n_time - 1;
1564
+ const int window_size = local_attn ? att_left + att_right + 1 : 2 * n_time - 1;
1565
+ const int d_half = n_state / 2;
1566
+ const int mask_dim = local_attn ? window_size : n_time;
1567
+
1568
+ // mask [key, n_time]
1569
+ struct ggml_tensor * attn_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mask_dim, n_time);
1570
+ ggml_set_name(attn_mask, "attn_mask");
1571
+ ggml_set_input(attn_mask);
1572
+
1573
+ struct ggml_tensor * local_mask = nullptr;
1574
+ if (local_attn) {
1575
+ const int chunk = att_left + att_right;
1576
+ local_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, chunk + window_size - 1, chunk);
1577
+ ggml_set_name(local_mask, "local_mask");
1578
+ ggml_set_input(local_mask);
1579
+ }
1580
+
1581
+ struct ggml_tensor * pos_freqs = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_half);
1582
+ ggml_set_name(pos_freqs, "pos_freqs");
1583
+ ggml_set_input(pos_freqs);
1584
+
1585
+ struct ggml_tensor * rel_positions = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, window_size);
1586
+ ggml_set_name(rel_positions, "rel_positions");
1587
+ ggml_set_input(rel_positions);
1588
+
1589
+ struct ggml_tensor * freqs = ggml_repeat_4d(ctx0, pos_freqs, d_half, window_size, 1, 1);
1590
+ struct ggml_tensor * theta = ggml_mul(ctx0, freqs, rel_positions);
1591
+
1592
+ struct ggml_tensor * sin_t = ggml_reshape_3d(ctx0, ggml_sin(ctx0, theta), 1, d_half, window_size);
1593
+ struct ggml_tensor * cos_t = ggml_reshape_3d(ctx0, ggml_cos(ctx0, theta), 1, d_half, window_size);
1594
+ // [n_state, window_size]
1595
+ struct ggml_tensor * pos_emb = ggml_reshape_2d(ctx0, ggml_cont(ctx0, ggml_concat(ctx0, sin_t, cos_t, 0)), n_state, window_size);
1596
+ ggml_set_name(pos_emb, "pos_emb");
1597
+
1598
+ for (int il = 0; il < n_layer; ++il) {
1599
+ const auto & layer = model.layers[il];
1600
+
1601
+ // FFN1
1602
+ {
1603
+ struct ggml_tensor * residual = cur;
1604
+ ggml_format_name(cur, "enc_%d_res", il);
1605
+
1606
+ // norm
1607
+ cur = ggml_norm(ctx0, cur, hparams.eps);
1608
+ cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_ff1_w), layer.norm_ff1_b);
1609
+ ggml_format_name(cur, "enc_%d_ffn_norm_1", il);
1610
+
1611
+ // ffn_1
1612
+ cur = ggml_mul_mat(ctx0, layer.ff1_linear1_w, cur);
1613
+ cur = ggml_silu(ctx0, cur);
1614
+ ggml_format_name(cur, "enc_%d_silu", il);
1615
+
1616
+ cur = ggml_mul_mat(ctx0, layer.ff1_linear2_w, cur);
1617
+ ggml_format_name(cur, "enc_%d_ffn_1", il);
1618
+
1619
+ cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor));
1620
+ ggml_format_name(cur, "enc_%d_res_ffn", il);
1621
+ }
1622
+
1623
+ // self attention block using relative positional encoding computed in graph.
1624
+ {
1625
+ // [feat, time_frames, 1, 1]
1626
+ struct ggml_tensor * residual = cur;
1627
+
1628
+ cur = ggml_norm(ctx0, cur, hparams.eps);
1629
+ cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_attn_w), layer.norm_attn_b);
1630
+ ggml_format_name(cur, "enc_%d_attn_norm", il);
1631
+
1632
+ const int n_head = hparams.n_audio_head;
1633
+ const int d_head = n_state / n_head;
1634
+
1635
+ // [feat, time_frames, 1, 1]
1636
+ struct ggml_tensor * Q_cur = ggml_mul_mat(ctx0, layer.attn_q_w, cur);
1637
+ struct ggml_tensor * K_cur = ggml_mul_mat(ctx0, layer.attn_k_w, cur);
1638
+ struct ggml_tensor * V_cur = ggml_mul_mat(ctx0, layer.attn_v_w, cur);
1639
+
1640
+ Q_cur = ggml_reshape_3d(ctx0, Q_cur, d_head, n_head, n_time);
1641
+ K_cur = ggml_reshape_3d(ctx0, K_cur, d_head, n_head, n_time);
1642
+ V_cur = ggml_reshape_3d(ctx0, V_cur, d_head, n_head, n_time);
1643
+
1644
+ struct ggml_tensor * pos = ggml_mul_mat(ctx0, layer.attn_pos_w, pos_emb);
1645
+ pos = ggml_reshape_3d(ctx0, pos, d_head, n_head, window_size);
1646
+ pos = ggml_cont(ctx0, ggml_permute(ctx0, pos, 0, 2, 1, 3));
1647
+
1648
+ if (local_attn) {
1649
+ const int chunk = att_left + att_right;
1650
+ const int n_group = (n_time + chunk - 1) / chunk;
1651
+ const int n_time_padded = n_group * chunk;
1652
+ const int n_kv_chunk = chunk + window_size - 1;
1653
+ const int n_kv_dense = n_kv_chunk * n_group;
1654
+ const bool need_padding = n_time_padded > n_time;
1655
+
1656
+ Q_cur = ggml_cont(ctx0, ggml_permute(ctx0, Q_cur, 0, 2, 1, 3));
1657
+ K_cur = ggml_cont(ctx0, ggml_permute(ctx0, K_cur, 0, 2, 1, 3));
1658
+ V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 0, 2, 1, 3));
1659
+
1660
+ // content bias
1661
+ struct ggml_tensor * bias_u = ggml_reshape_3d(ctx0, layer.attn_pos_bias_u, d_head, 1, n_head);
1662
+ struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, bias_u);
1663
+
1664
+ // position bias
1665
+ struct ggml_tensor * bias_v = ggml_reshape_3d(ctx0, layer.attn_pos_bias_v, d_head, 1, n_head);
1666
+ struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, bias_v);
1667
+
1668
+ // right pad the time_frame.
1669
+ struct ggml_tensor * Q_u_padded = need_padding ?
1670
+ ggml_pad_ext(ctx0, Q_u, 0, 0, 0, n_time_padded - n_time, 0, 0, 0, 0) : Q_u;
1671
+ Q_u_padded = ggml_reshape_4d(ctx0, Q_u_padded, d_head, chunk, n_group, n_head);
1672
+
1673
+ // Add padding to front and back (for the first timeframe and the last timeframe).
1674
+ struct ggml_tensor * K_padded = ggml_pad_ext(ctx0, K_cur, 0, 0, att_left, att_right, 0, 0, 0, 0);
1675
+
1676
+ // pad time axis to match n_kv_dense if needed.
1677
+ if (n_kv_dense > K_padded->ne[1]) {
1678
+ K_padded = ggml_pad_ext(ctx0, K_padded, 0, 0, 0, n_kv_dense - K_padded->ne[1], 0, 0, 0, 0);
1679
+ }
1680
+
1681
+ // Create a 4d tensor where each group spans a wide window of
1682
+ // 512 keys (n_kv_chunk), but moving to the next group (nb[2])
1683
+ // only jumps forward by 256 frames (chunk * nb[1]). This creates
1684
+ // a 256 frame overlap, shared keys in RAM without copies.
1685
+ struct ggml_tensor * K_chunk = ggml_view_4d(ctx0, K_padded,
1686
+ d_head, n_kv_chunk, n_group, n_head,
1687
+ K_padded->nb[1],
1688
+ (size_t) chunk * K_padded->nb[1],
1689
+ K_padded->nb[2],
1690
+ 0);
1691
+ K_chunk = ggml_cont(ctx0, K_chunk);
1692
+
1693
+ struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_chunk, Q_u_padded);
1694
+
1695
+ // The above mul_mat operation, combined with K_chunk's overlapping
1696
+ // frames, produces a dense matrix. But some of the results in
1697
+ // this matrix were computed for keys that aren't part of that
1698
+ // query's window. So we shift each row to keep only the results
1699
+ // that we want.
1700
+ content_scores = ggml_view_4d(ctx0, content_scores,
1701
+ window_size, chunk, n_group, n_head,
1702
+ (size_t) (chunk + window_size) * content_scores->nb[0],
1703
+ content_scores->nb[2],
1704
+ content_scores->nb[3],
1705
+ 0);
1706
+ content_scores = ggml_cont(ctx0, content_scores);
1707
+
1708
+ // ungrouping.
1709
+ content_scores = ggml_reshape_3d(ctx0, content_scores, window_size, n_time_padded, n_head);
1710
+
1711
+ // remove padding if padding was applied (truncating to n_time).
1712
+ if (need_padding) {
1713
+ content_scores = ggml_view_3d(ctx0, content_scores,
1714
+ window_size, n_time, n_head,
1715
+ content_scores->nb[1],
1716
+ content_scores->nb[2],
1717
+ 0);
1718
+ }
1719
+
1720
+ struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v);
1721
+
1722
+ // attention_score = content similarity + relative position scores
1723
+ struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores);
1724
+
1725
+ attn_scores = ggml_soft_max_ext(ctx0, attn_scores, attn_mask, 1.0f / std::sqrt(d_head), 0.0f);
1726
+
1727
+ // right pad the probabilites.
1728
+ struct ggml_tensor * probs_padded = need_padding ?
1729
+ ggml_pad_ext(ctx0, attn_scores, 0, 0, 0, n_time_padded - n_time, 0, 0, 0, 0) : attn_scores;
1730
+
1731
+ probs_padded = ggml_reshape_4d(ctx0, probs_padded, window_size, chunk, n_group, n_head);
1732
+ probs_padded = ggml_pad_ext(ctx0, probs_padded, 0, chunk, 0, 0, 0, 0, 0, 0);
1733
+ probs_padded = ggml_view_4d(ctx0, probs_padded,
1734
+ n_kv_chunk, chunk, n_group, n_head,
1735
+ (size_t) n_kv_chunk * probs_padded->nb[0],
1736
+ probs_padded->nb[2],
1737
+ probs_padded->nb[3],
1738
+ 0);
1739
+ probs_padded = ggml_cont(ctx0, probs_padded);
1740
+ probs_padded = ggml_mul(ctx0, probs_padded, local_mask);
1741
+
1742
+ // Add padding to front and back (for the first timeframe and the last timeframe).
1743
+ struct ggml_tensor * V_padded = ggml_pad_ext(ctx0, V_cur, 0, 0, att_left, att_right, 0, 0, 0, 0);
1744
+
1745
+ // pad time axis to match n_kv_dense if needed.
1746
+ if (n_kv_dense > V_padded->ne[1]) {
1747
+ V_padded = ggml_pad_ext(ctx0, V_padded, 0, 0, 0, n_kv_dense - V_padded->ne[1], 0, 0, 0, 0);
1748
+ }
1749
+
1750
+ V_padded = ggml_cont(ctx0, ggml_transpose(ctx0, V_padded));
1751
+
1752
+ struct ggml_tensor * V_chunk = ggml_view_4d(ctx0, V_padded,
1753
+ n_kv_chunk, d_head, n_group, n_head,
1754
+ V_padded->nb[1],
1755
+ (size_t) chunk * V_padded->nb[0],
1756
+ V_padded->nb[2],
1757
+ 0);
1758
+ V_chunk = ggml_cont(ctx0, V_chunk);
1759
+
1760
+ cur = ggml_mul_mat(ctx0, V_chunk, probs_padded);
1761
+ // ungroup.
1762
+ cur = ggml_reshape_3d(ctx0, cur, d_head, n_time_padded, n_head);
1763
+ // unpad
1764
+ if (need_padding) {
1765
+ cur = ggml_view_3d(ctx0, cur, d_head, n_time, n_head, cur->nb[1], cur->nb[2], 0);
1766
+ }
1767
+
1768
+ cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
1769
+ cur = ggml_reshape_2d(ctx0, cur, n_state, n_time);
1770
+ cur = ggml_mul_mat(ctx0, layer.attn_out_w, cur);
1771
+ } else {
1772
+ struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, layer.attn_pos_bias_u);
1773
+ ggml_format_name(Q_u, "enc_%d_attn_q_u", il);
1774
+
1775
+ struct ggml_tensor * K_prep = ggml_permute(ctx0, K_cur, 0, 2, 1, 3);
1776
+ struct ggml_tensor * Q_prep = ggml_permute(ctx0, Q_u, 0, 2, 1, 3);
1777
+ struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_prep, Q_prep);
1778
+ ggml_format_name(content_scores, "enc_%d_attn_content_scores", il);
1779
+
1780
+ struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, layer.attn_pos_bias_v);
1781
+ ggml_format_name(Q_v, "enc_%d_attn_q_v", il);
1782
+
1783
+ Q_v = ggml_permute(ctx0, Q_v, 0, 2, 1, 3);
1784
+ Q_v = ggml_cont(ctx0, Q_v);
1785
+ ggml_format_name(Q_v, "enc_%d_attn_q_v_perm", il);
1786
+
1787
+ struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v);
1788
+ ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos", il);
1789
+
1790
+ // Relative position shifting is performed in the following block.
1791
+ // Some more details on the operations performed below can be found here:
1792
+ // https://github.com/danbev/learning-ai/blob/main/notes/whisper/parakeet.md#relative-position-shift
1793
+ {
1794
+ const auto pos_window = rel_pos_scores->ne[0];
1795
+ const auto n_frame = rel_pos_scores->ne[1];
1796
+ const auto n_head_cur = rel_pos_scores->ne[2];
1797
+
1798
+ rel_pos_scores = ggml_pad(ctx0, rel_pos_scores, 1, 0, 0, 0);
1799
+ rel_pos_scores = ggml_roll(ctx0, rel_pos_scores, 1, 0, 0, 0);
1800
+
1801
+ rel_pos_scores = ggml_reshape_3d(ctx0, rel_pos_scores, n_frame, pos_window + 1, n_head_cur);
1802
+ ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_reshaped", il);
1803
+
1804
+ int center = pos_window / 2;
1805
+ size_t offset = rel_pos_scores->nb[0] * (center+1);
1806
+
1807
+ rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores,
1808
+ n_frame, pos_window, n_head_cur,
1809
+ (pos_window) * 4,
1810
+ rel_pos_scores->nb[2],
1811
+ offset);
1812
+
1813
+ ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted", il);
1814
+
1815
+ rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores,
1816
+ content_scores->ne[0],
1817
+ content_scores->ne[1],
1818
+ rel_pos_scores->ne[2],
1819
+ rel_pos_scores->nb[1],
1820
+ rel_pos_scores->nb[2],
1821
+ 0);
1822
+ rel_pos_scores = ggml_cont(ctx0, rel_pos_scores);
1823
+ ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted_view", il);
1824
+ }
1825
+
1826
+ struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores);
1827
+ ggml_format_name(attn_scores, "enc_%d_attn_scores", il);
1828
+ attn_scores = ggml_scale(ctx0, attn_scores, 1.0f / std::sqrt(d_head));
1829
+ attn_scores = ggml_add(ctx0, attn_scores, attn_mask);
1830
+ ggml_format_name(attn_scores, "enc_%d_attn_scores_scaled", il);
1831
+
1832
+ struct ggml_tensor * probs = ggml_soft_max(ctx0, attn_scores);
1833
+ ggml_format_name(probs, "enc_%d_attn_probs", il);
1834
+
1835
+ V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 1, 2, 0, 3));
1836
+ ggml_format_name(V_cur, "enc_%d_attn_v_cur", il);
1837
+ cur = ggml_mul_mat(ctx0, probs, V_cur);
1838
+ ggml_format_name(cur, "enc_%d_attn_inp", il);
1839
+
1840
+ cur = ggml_permute(ctx0, cur, 2, 0, 1, 3);
1841
+ cur = ggml_cont_2d(ctx0, cur, n_state, n_time);
1842
+ cur = ggml_mul_mat(ctx0, layer.attn_out_w, cur);
1843
+ }
1844
+ ggml_format_name(cur, "enc_%d_attn_out", il);
1845
+
1846
+ cur = ggml_add(ctx0, residual, cur);
1847
+ ggml_format_name(cur, "enc_%d_attn_res", il);
1848
+ }
1849
+
1850
+ // Convolution
1851
+ {
1852
+ struct ggml_tensor * residual = cur;
1853
+ ggml_format_name(cur, "enc_%d_residual_conv", il);
1854
+
1855
+ cur = ggml_norm(ctx0, cur, hparams.eps);
1856
+ cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_conv_w), layer.norm_conv_b);
1857
+ ggml_format_name(cur, "enc_%d_norm_conv", il);
1858
+
1859
+ // pointwise 1d convolution: [1024, 138] -> [2048, 138]
1860
+ cur = ggml_mul_mat(ctx0, layer.conv_pw1_w, cur);
1861
+ ggml_format_name(cur, "enc_%d_conv_pw1", il);
1862
+
1863
+ {
1864
+ int64_t d = cur->ne[0] / 2;
1865
+ struct ggml_tensor * signal = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], 0);
1866
+ struct ggml_tensor * gate = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], d * cur->nb[0]);
1867
+
1868
+ cur = ggml_mul(ctx0, signal, ggml_sigmoid(ctx0, gate));
1869
+ ggml_format_name(cur, "enc_%d_conv_glu", il);
1870
+ }
1871
+
1872
+ cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
1873
+
1874
+ // use ggml_ssm_conv for f32 precision
1875
+ const int dw_pad = (hparams.n_conv_kernel - 1) / 2;
1876
+ cur = ggml_pad(ctx0, cur, dw_pad, 0, 0, 0);
1877
+ cur = ggml_roll(ctx0, cur, dw_pad, 0, 0, 0);
1878
+ cur = ggml_pad(ctx0, cur, dw_pad, 0, 0, 0);
1879
+ ggml_format_name(cur, "enc_%d_conv_dw_pad", il);
1880
+
1881
+ cur = ggml_ssm_conv(ctx0, cur, layer.conv_dw_w);
1882
+ ggml_format_name(cur, "enc_%d_conv_1d_dw", il);
1883
+
1884
+ cur = ggml_sub(ctx0, cur, layer.conv_bn_mean);
1885
+ struct ggml_tensor * std = ggml_sqrt(ctx0, layer.conv_bn_var);
1886
+ cur = ggml_div(ctx0, cur, std);
1887
+ cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.conv_bn_w), layer.conv_bn_b);
1888
+ ggml_format_name(cur, "enc_%d_conv_bn", il);
1889
+
1890
+ cur = ggml_silu(ctx0, cur);
1891
+ ggml_format_name(cur, "enc_%d_conv_silu", il);
1892
+
1893
+ cur = ggml_mul_mat(ctx0, layer.conv_pw2_w, cur);
1894
+ ggml_format_name(cur, "enc_%d_conv_pw2", il);
1895
+
1896
+ cur = ggml_add(ctx0, residual, cur);
1897
+ ggml_format_name(cur, "enc_%d_conv_res", il);
1898
+ }
1899
+
1900
+ // FFN2
1901
+ {
1902
+ struct ggml_tensor * residual = cur;
1903
+ cur = ggml_norm(ctx0, cur, hparams.eps);
1904
+ cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_ff2_w), layer.norm_ff2_b);
1905
+ ggml_format_name(cur, "enc_%d_ffn_norm_2", il);
1906
+
1907
+ cur = ggml_mul_mat(ctx0, layer.ff2_linear1_w, cur);
1908
+ cur = ggml_silu(ctx0, cur);
1909
+ cur = ggml_mul_mat(ctx0, layer.ff2_linear2_w, cur);
1910
+ cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, 0.5));
1911
+ ggml_format_name(cur, "enc_%d_ffn_res", il);
1912
+ }
1913
+
1914
+ cur = ggml_norm(ctx0, cur, hparams.eps);
1915
+ cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_out_w), layer.norm_out_b);
1916
+ }
1917
+
1918
+ ggml_set_name(cur, "encoder_out");
1919
+ pstate.n_frames = cur->ne[1];
1920
+
1921
+ struct ggml_tensor * enc_out_view = ggml_view_2d(ctx0, pstate.enc_out, n_state, pstate.n_frames, pstate.enc_out->nb[1], 0);
1922
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, cur, enc_out_view));
1923
+
1924
+ ggml_free(ctx0);
1925
+
1926
+ return gf;
1927
+ }
1928
+
1929
+ static bool parakeet_encode_internal(
1930
+ parakeet_context & pctx,
1931
+ parakeet_state & pstate,
1932
+ const int mel_offset,
1933
+ const int n_threads,
1934
+ ggml_abort_callback abort_callback,
1935
+ void * abort_callback_data) {
1936
+ const int64_t t_start_us = ggml_time_us();
1937
+
1938
+ auto & sched = pstate.sched_encode.sched;
1939
+
1940
+ ggml_cgraph * gf = parakeet_build_graph_encode(pctx, pstate);
1941
+
1942
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
1943
+ // should never happen as we pre-allocate the memory
1944
+ return false;
1945
+ }
1946
+
1947
+ // set mel input
1948
+ {
1949
+ struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
1950
+
1951
+ const auto & mel_inp = pstate.mel;
1952
+ const int n_ctx = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : pctx.model.hparams.n_audio_ctx;
1953
+
1954
+ assert(mel->type == GGML_TYPE_F32);
1955
+ assert(mel_inp.n_mel == pctx.model.hparams.n_mels);
1956
+
1957
+ pstate.inp_mel.resize(ggml_nelements(mel));
1958
+
1959
+ float * dst = pstate.inp_mel.data();
1960
+ memset(dst, 0, ggml_nbytes(mel));
1961
+
1962
+ const int i0 = std::min(mel_offset, mel_inp.n_len);
1963
+ const int i1 = std::min(mel_offset + n_ctx, mel_inp.n_len);
1964
+
1965
+ memcpy(dst, mel_inp.data.data() + i0 * mel_inp.n_mel, (i1 - i0) * mel_inp.n_mel * sizeof(float));
1966
+
1967
+ ggml_backend_tensor_set(mel, pstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
1968
+ }
1969
+
1970
+ // set attention mask
1971
+ {
1972
+ struct ggml_tensor * attn_mask = ggml_graph_get_tensor(gf, "attn_mask");
1973
+ const int n_q = attn_mask->ne[1];
1974
+ const int n_k = attn_mask->ne[0];
1975
+
1976
+ const int32_t subsampl_factor = pctx.model.hparams.subsampling_factor;
1977
+ const int n_tokens_real = (pstate.mel.n_len_org + subsampl_factor - 1) / subsampl_factor;
1978
+
1979
+ std::vector<float> mask_data(n_q * n_k);
1980
+ const float mask_value = -1e30f;
1981
+
1982
+ if (n_k == n_q) { // full attention
1983
+ for (int q = 0; q < n_q; ++q) {
1984
+ for (int k = 0; k < n_k; ++k) {
1985
+ mask_data[q * n_k + k] = (k >= n_tokens_real) ? mask_value : 0.0f;
1986
+ }
1987
+ }
1988
+ } else { // local attention
1989
+ const int att_left = n_k / 2;
1990
+ for (int q = 0; q < n_q; ++q) {
1991
+ for (int k = 0; k < n_k; ++k) {
1992
+ const int key = q - att_left + k;
1993
+ mask_data[q * n_k + k] = (key >= 0 && key < n_tokens_real) ? 0.0f : mask_value;
1994
+ }
1995
+ }
1996
+ }
1997
+ ggml_backend_tensor_set(attn_mask, mask_data.data(), 0, mask_data.size() * sizeof(float));
1998
+ }
1999
+
2000
+ // set local attention skew mask
2001
+ if (struct ggml_tensor * local_mask = ggml_graph_get_tensor(gf, "local_mask")) {
2002
+ const int n_k = local_mask->ne[0];
2003
+ const int n_q = local_mask->ne[1];
2004
+
2005
+ std::vector<float> mask_data(n_q * n_k);
2006
+ const int window_size = n_k - n_q + 1;
2007
+ for (int q = 0; q < n_q; ++q) {
2008
+ for (int k = 0; k < n_k; ++k) {
2009
+ const int rel = k - q;
2010
+ mask_data[q * n_k + k] = (rel >= 0 && rel < window_size) ? 1.0f : 0.0f;
2011
+ }
2012
+ }
2013
+ ggml_backend_tensor_set(local_mask, mask_data.data(), 0, mask_data.size() * sizeof(float));
2014
+ }
2015
+
2016
+ // set positional frequency
2017
+ {
2018
+ struct ggml_tensor * pos_freqs_t = ggml_graph_get_tensor(gf, "pos_freqs");
2019
+ const int d_half = pos_freqs_t->ne[0];
2020
+ const int n_state = pctx.model.hparams.n_audio_state;
2021
+ const float log_10000 = logf(10000.0f);
2022
+ std::vector<float> freqs(d_half);
2023
+ for (int k = 0; k < d_half; ++k) {
2024
+ freqs[k] = expf(-(float(k * 2) * log_10000 / float(n_state)));
2025
+ }
2026
+ ggml_backend_tensor_set(pos_freqs_t, freqs.data(), 0, freqs.size() * sizeof(float));
2027
+ }
2028
+
2029
+ // set relative position offsets
2030
+ {
2031
+ struct ggml_tensor * rel_pos_t = ggml_graph_get_tensor(gf, "rel_positions");
2032
+ const int window_size = rel_pos_t->ne[1];
2033
+ std::vector<float> pos(window_size);
2034
+ if (window_size == PARAKEET_LOCAL_ATTN_WINDOW * 2 + 1) {
2035
+ for (int t = 0; t < window_size; ++t) {
2036
+ pos[t] = float(PARAKEET_LOCAL_ATTN_WINDOW - t);
2037
+ }
2038
+ } else {
2039
+ const int n_time = (window_size + 1) / 2;
2040
+ for (int t = 0; t < window_size; ++t) {
2041
+ pos[t] = float(n_time - 1 - t);
2042
+ }
2043
+ }
2044
+ ggml_backend_tensor_set(rel_pos_t, pos.data(), 0, pos.size() * sizeof(float));
2045
+ }
2046
+
2047
+ if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
2048
+ return false;
2049
+ }
2050
+
2051
+ pstate.t_encode_us += ggml_time_us() - t_start_us;
2052
+ pstate.n_encode++;
2053
+
2054
+ return !(abort_callback && abort_callback(abort_callback_data));
2055
+ }
2056
+
2057
+ static bool parakeet_ensure_encode_sched(
2058
+ parakeet_context & pctx,
2059
+ parakeet_state & pstate,
2060
+ int n_audio_ctx) {
2061
+ if (pstate.sched_encode.sched && pstate.sched_encode_n_audio_ctx == n_audio_ctx) {
2062
+ return true;
2063
+ }
2064
+
2065
+ parakeet_sched_free(pstate.sched_encode);
2066
+
2067
+ const int32_t prev_n_audio_ctx = pstate.n_audio_ctx;
2068
+ pstate.n_audio_ctx = n_audio_ctx;
2069
+
2070
+ const int subsampl_factor = pctx.model.hparams.subsampling_factor;
2071
+ const int n_frames_max = (n_audio_ctx + subsampl_factor - 1) / subsampl_factor;
2072
+ if (n_frames_max > pstate.enc_out->ne[1]) {
2073
+ ggml_backend_buffer_free(pstate.enc_out_buffer);
2074
+ pstate.enc_out_buffer = nullptr;
2075
+ pstate.enc_out = nullptr;
2076
+
2077
+ if (!parakeet_enc_state_init(pstate, pstate.backends[0], pctx.model.hparams.n_audio_state, n_frames_max)) {
2078
+ pstate.sched_encode_n_audio_ctx = 0;
2079
+ pstate.n_audio_ctx = prev_n_audio_ctx;
2080
+ return false;
2081
+ }
2082
+ }
2083
+
2084
+ const bool ok = parakeet_sched_graph_init(pstate.sched_encode, pstate.backends,
2085
+ [&]() {
2086
+ return parakeet_build_graph_encode(pctx, pstate);
2087
+ });
2088
+
2089
+ if (!ok) {
2090
+ pstate.sched_encode_n_audio_ctx = 0;
2091
+ pstate.n_audio_ctx = prev_n_audio_ctx;
2092
+ return false;
2093
+ }
2094
+
2095
+ pstate.sched_encode_n_audio_ctx = n_audio_ctx;
2096
+ return true;
2097
+ }
2098
+
2099
+ static struct ggml_tensor * parakeet_build_graph_lstm_layer(
2100
+ struct ggml_context * ctx0,
2101
+ struct ggml_cgraph * gf,
2102
+ struct ggml_tensor * x_t, // the current input token embedding
2103
+ struct ggml_tensor * w_ih, // input to hidden weights (4 weight tensors packed)
2104
+ struct ggml_tensor * w_hh, // hidden to hidden weights (4 weight tensors packed)
2105
+ struct ggml_tensor * b_h, // folded ih+hh bias (4 bias tensors packed)
2106
+ struct ggml_tensor * h_state, // this layers hidden state
2107
+ struct ggml_tensor * c_state, // this layers cell state
2108
+ int li) { // layer index (for tensor naming)
2109
+
2110
+ ggml_format_name(x_t, "lstm_layer_%d_x_t", li);
2111
+ ggml_format_name(h_state, "lstm_layer_%d_h_state", li);
2112
+ ggml_format_name(c_state, "lstm_layer_%d_c_state", li);
2113
+
2114
+ // The 4 gates (i, f, o, c) are packed in the same weight tensor.
2115
+ struct ggml_tensor * inp_gates = ggml_mul_mat(ctx0, w_ih, x_t);
2116
+
2117
+ // Hidden-to-Hidden Projections are also packed in the same weight tensor.
2118
+ // b_h holds the folded ih+hh bias (see parakeet_model_load), so it is
2119
+ // the only bias that needs to be added here.
2120
+ struct ggml_tensor * hid_gates = ggml_mul_mat(ctx0, w_hh, h_state);
2121
+ hid_gates = ggml_add(ctx0, hid_gates, b_h);
2122
+
2123
+ // Combine the input and hidden contributions of the gates.
2124
+ struct ggml_tensor * gates = ggml_add(ctx0, inp_gates, hid_gates);
2125
+ ggml_format_name(gates, "lstm_layer_%d_gates", li);
2126
+
2127
+ const int h_dim = h_state->ne[0];
2128
+ const size_t row_size = ggml_row_size(gates->type, h_dim);
2129
+
2130
+ // The gates are packed as [i, f, o, c] (reordered at convert time, see
2131
+ // parakeet_model_load), so the three sigmoid-gated outputs (i, f, o) are
2132
+ // contiguous and can be computed with a single ggml_sigmoid call.
2133
+ struct ggml_tensor * ifo = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, 3 * h_dim, 0));
2134
+ ggml_format_name(ifo, "lstm_layer_%d_ifo", li);
2135
+
2136
+ // 1. Input Gate at time t.
2137
+ struct ggml_tensor * i_t = ggml_view_1d(ctx0, ifo, h_dim, 0 * row_size);
2138
+ ggml_format_name(i_t, "lstm_layer_%d_i_t", li);
2139
+
2140
+ // Forget gate.
2141
+ struct ggml_tensor * f_t = ggml_view_1d(ctx0, ifo, h_dim, 1 * row_size);
2142
+ ggml_format_name(f_t, "lstm_layer_%d_f_t", li);
2143
+
2144
+ // Output gate.
2145
+ struct ggml_tensor * o_t = ggml_view_1d(ctx0, ifo, h_dim, 2 * row_size);
2146
+ ggml_format_name(o_t, "lstm_layer_%d_o_t", li);
2147
+
2148
+ // Cell gate.
2149
+ struct ggml_tensor * c_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, gates, h_dim, 3 * row_size));
2150
+ ggml_format_name(c_t, "lstm_layer_%d_c_t", li);
2151
+
2152
+ // Calculate the new cell state.
2153
+ struct ggml_tensor * c_new = ggml_add(ctx0,
2154
+ ggml_mul(ctx0, f_t, c_state), // apply forget gate to cell state.
2155
+ ggml_mul(ctx0, i_t, c_t)); // apply input gate to cell gate.
2156
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_new, c_state));
2157
+
2158
+ // Calculate the new hidden state.
2159
+ struct ggml_tensor * h_new = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_new));
2160
+ ggml_set_output(h_new);
2161
+ ggml_format_name(h_new, "lstm_layer_%d_h_new", li);
2162
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_new, h_state));
2163
+
2164
+ return h_new;
2165
+ }
2166
+
2167
+ static struct ggml_cgraph * parakeet_build_graph_prediction(
2168
+ parakeet_context & pctx,
2169
+ parakeet_state & pstate,
2170
+ const parakeet_batch & batch,
2171
+ bool worst_case) {
2172
+ GGML_UNUSED(worst_case);
2173
+ const auto & model = pctx.model;
2174
+ const auto & hparams = model.hparams;
2175
+ const int n_tokens = batch.n_tokens;
2176
+
2177
+ struct ggml_init_params params = {
2178
+ /*.mem_size =*/ pstate.sched_decode.meta.size(),
2179
+ /*.mem_buffer =*/ pstate.sched_decode.meta.data(),
2180
+ /*.no_alloc =*/ true,
2181
+ };
2182
+
2183
+ struct ggml_context * ctx0 = ggml_init(params);
2184
+ ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false);
2185
+
2186
+ // Prediction Network
2187
+ struct ggml_tensor * token = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
2188
+ ggml_set_name(token, "token_inp");
2189
+ ggml_set_input(token);
2190
+
2191
+ struct ggml_tensor * token_embd = ggml_get_rows(ctx0, model.prediction.embed_w, token);
2192
+
2193
+ struct ggml_tensor * inpL = token_embd;
2194
+
2195
+ for (int il = 0; il < hparams.n_pred_layers; ++il) {
2196
+ inpL = parakeet_build_graph_lstm_layer(ctx0, gf, inpL,
2197
+ model.prediction.lstm_layer[il].ih_w,
2198
+ model.prediction.lstm_layer[il].hh_w,
2199
+ model.prediction.lstm_layer[il].b_h,
2200
+ pstate.lstm_state.layer[il].h_state,
2201
+ pstate.lstm_state.layer[il].c_state,
2202
+ il);
2203
+ }
2204
+
2205
+ struct ggml_tensor * pred_out = inpL;
2206
+ ggml_format_name(pred_out, "lstm_pred_out");
2207
+
2208
+ // Project the prediction network output to the joint network hidden dimension.
2209
+ struct ggml_tensor * pred = ggml_mul_mat(ctx0, model.joint.pred_w, pred_out);
2210
+ pred = ggml_add(ctx0, pred, model.joint.pred_b);
2211
+ ggml_set_name(pred, "h_pred");
2212
+
2213
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, pred, pstate.pred_out));
2214
+
2215
+ ggml_free(ctx0);
2216
+
2217
+ return gf;
2218
+ }
2219
+
2220
+ static struct ggml_cgraph * parakeet_build_graph_joint(
2221
+ parakeet_context & pctx,
2222
+ parakeet_state & pstate,
2223
+ const parakeet_batch & batch,
2224
+ bool worst_case) {
2225
+ GGML_UNUSED(worst_case);
2226
+ const auto & model = pctx.model;
2227
+ const auto & hparams = model.hparams;
2228
+
2229
+ struct ggml_init_params params = {
2230
+ /*.mem_size =*/ pstate.sched_decode.meta.size(),
2231
+ /*.mem_buffer =*/ pstate.sched_decode.meta.data(),
2232
+ /*.no_alloc =*/ true,
2233
+ };
2234
+
2235
+ struct ggml_context * ctx0 = ggml_init(params);
2236
+ ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false);
2237
+
2238
+ struct ggml_tensor * pred = pstate.pred_out;
2239
+ ggml_format_name(pred, "pred");
2240
+
2241
+ const int t_idx = batch.i_time[0];
2242
+ struct ggml_tensor * enc_out = ggml_view_1d(ctx0, pstate.enc_out, hparams.n_audio_state,
2243
+ (size_t) t_idx * pstate.enc_out->nb[1]);
2244
+ ggml_format_name(enc_out, "enc_out_view");
2245
+
2246
+ // Project the encoder output to the joint network hidden dimension.
2247
+ struct ggml_tensor * enc = ggml_mul_mat(ctx0, model.joint.enc_w, enc_out);
2248
+ enc = ggml_add(ctx0, enc, model.joint.enc_b);
2249
+ ggml_set_name(enc, "enc");
2250
+
2251
+ struct ggml_tensor * joint = ggml_add(ctx0, enc, pred);
2252
+ ggml_set_name(joint, "joint");
2253
+ joint = ggml_relu(ctx0, joint);
2254
+
2255
+ struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.joint.net_w, joint);
2256
+ logits = ggml_add(ctx0, logits, model.joint.net_b);
2257
+ ggml_set_output(logits);
2258
+ ggml_set_name(logits, "logits");
2259
+
2260
+ struct ggml_tensor * probs = ggml_soft_max(ctx0, logits);
2261
+ struct ggml_tensor * log_probs = ggml_log(ctx0, probs);
2262
+ ggml_set_output(log_probs);
2263
+ ggml_format_name(log_probs, "log_probs");
2264
+
2265
+ ggml_build_forward_expand(gf, log_probs);
2266
+
2267
+ ggml_free(ctx0);
2268
+
2269
+ return gf;
2270
+ }
2271
+
2272
+ static bool parakeet_predict(
2273
+ parakeet_context & pctx,
2274
+ parakeet_state & pstate,
2275
+ const parakeet_batch & batch,
2276
+ const int n_threads,
2277
+ ggml_abort_callback abort_callback,
2278
+ void * abort_callback_data) {
2279
+
2280
+ const int n_tokens = batch.n_tokens;
2281
+
2282
+ const int64_t t_start_us = ggml_time_us();
2283
+
2284
+ {
2285
+ auto & sched = pstate.sched_decode.sched;
2286
+
2287
+ const int64_t t_build_start_us = ggml_time_us();
2288
+ ggml_cgraph * gf = parakeet_build_graph_prediction(pctx, pstate, batch, false);
2289
+ pstate.t_predict_build_us += ggml_time_us() - t_build_start_us;
2290
+
2291
+ const int64_t t_alloc_start_us = ggml_time_us();
2292
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
2293
+ // should never happen as we pre-allocate the memory
2294
+ return false;
2295
+ }
2296
+ pstate.t_predict_alloc_us += ggml_time_us() - t_alloc_start_us;
2297
+
2298
+ // set the inputs
2299
+ {
2300
+ struct ggml_tensor * token_inp = ggml_graph_get_tensor(gf, "token_inp");
2301
+ ggml_backend_tensor_set(token_inp, batch.token, 0, n_tokens * ggml_element_size(token_inp));
2302
+ }
2303
+
2304
+ const int64_t t_compute_start_us = ggml_time_us();
2305
+ if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
2306
+ return false;
2307
+ }
2308
+ pstate.t_predict_compute_us += ggml_time_us() - t_compute_start_us;
2309
+ }
2310
+
2311
+ pstate.t_predict_us += ggml_time_us() - t_start_us;
2312
+ pstate.n_predict++;
2313
+
2314
+ return !(abort_callback && abort_callback(abort_callback_data));
2315
+ }
2316
+
2317
+ static bool parakeet_joint(
2318
+ parakeet_context & pctx,
2319
+ parakeet_state & pstate,
2320
+ const parakeet_batch & batch,
2321
+ const int n_threads,
2322
+ ggml_abort_callback abort_callback,
2323
+ void * abort_callback_data) {
2324
+ const int64_t t_start_us = ggml_time_us();
2325
+
2326
+ const auto & model = pctx.model;
2327
+ const auto & hparams = model.hparams;
2328
+ const int n_tokens = batch.n_tokens;
2329
+
2330
+ auto & logits_out = pstate.logits;
2331
+
2332
+ struct ggml_tensor * logits;
2333
+
2334
+ {
2335
+ auto & sched = pstate.sched_decode.sched;
2336
+
2337
+ ggml_cgraph * gf = parakeet_build_graph_joint(pctx, pstate, batch, false);
2338
+
2339
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
2340
+ // should never happen as we pre-allocate the memory
2341
+ return false;
2342
+ }
2343
+
2344
+ logits = ggml_graph_node(gf, -1);
2345
+
2346
+ if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
2347
+ return false;
2348
+ }
2349
+
2350
+ }
2351
+
2352
+ const int n_logits = hparams.n_vocab + hparams.n_tdt_durations + 1; // one for the blank token
2353
+ logits_out.resize(n_tokens * n_logits);
2354
+ for (int i = 0; i < n_tokens; i++) {
2355
+ if (batch.logits[i] == 0) {
2356
+ continue;
2357
+ }
2358
+ ggml_backend_tensor_get(logits, logits_out.data() + (n_logits*i), sizeof(float)*(n_logits*i), sizeof(float)*n_logits);
2359
+ }
2360
+
2361
+ if (batch.n_tokens == 1) {
2362
+ pstate.t_decode_us += ggml_time_us() - t_start_us;
2363
+ pstate.n_decode++;
2364
+ }
2365
+
2366
+ return !(abort_callback && abort_callback(abort_callback_data));
2367
+ }
2368
+
2369
+ static bool is_word_start_token(parakeet_vocab & vocab, parakeet_token token_id) {
2370
+ const std::string & token_str = vocab.id_to_token[token_id];
2371
+ // check if it starts with the SentencePiece meta-space "▁" (U+2581) or 3-byte UTF-8 character: 0xE2 0x96 0x81
2372
+ if (!token_str.empty()) {
2373
+ if (token_str.find("\xE2\x96\x81") == 0 || token_str[0] == '_') {
2374
+ return true;
2375
+ }
2376
+ }
2377
+ return false;
2378
+ }
2379
+
2380
+ static bool is_punctuation_token(parakeet_vocab & vocab, parakeet_token token_id) {
2381
+ const std::string & token_str = vocab.id_to_token[token_id];
2382
+ static const std::string punct_chars = ".,!?;:'\"-()[]{}";
2383
+
2384
+ if (token_str.empty()) {
2385
+ return false;
2386
+ }
2387
+
2388
+ std::string clean_token = token_str;
2389
+ if (clean_token.find("\xE2\x96\x81") == 0) {
2390
+ clean_token = clean_token.substr(3); // Remove the 3-byte UTF-8 character
2391
+ } else if (clean_token[0] == '_') {
2392
+ clean_token = clean_token.substr(1);
2393
+ }
2394
+
2395
+ return clean_token.length() == 1 && punct_chars.find(clean_token[0]) != std::string::npos;
2396
+ }
2397
+
2398
+ // Collapse punctuation timestamps to match the original Parakeet model.
2399
+ // Punctuation symbols like ',', '.' and others are not spoken words but the
2400
+ // model will still produce a duration for these tokens. But since these are
2401
+ // non-spoken we collapse the timestamps so that they don't have a time duration.
2402
+ static void refine_timestamps_tdt(parakeet_vocab & vocab, std::vector<parakeet_token_data> & tokens) {
2403
+ if (tokens.empty()) {
2404
+ return;
2405
+ }
2406
+
2407
+ int64_t last_non_punct_t1 = -1;
2408
+
2409
+ for (size_t i = 0; i < tokens.size(); ++i) {
2410
+ if (is_punctuation_token(vocab, tokens[i].id)) {
2411
+ if (last_non_punct_t1 >= 0) {
2412
+ tokens[i].t0 = last_non_punct_t1;
2413
+ tokens[i].t1 = last_non_punct_t1;
2414
+ }
2415
+ } else {
2416
+ last_non_punct_t1 = tokens[i].t1;
2417
+ }
2418
+ }
2419
+ }
2420
+
2421
+ static parakeet_token_data create_token_data(
2422
+ parakeet_context & pctx,
2423
+ parakeet_state & pstate,
2424
+ parakeet_token token_id,
2425
+ int duration_idx,
2426
+ int duration_value,
2427
+ int frame_index,
2428
+ float token_logit,
2429
+ int n_vocab_logits) {
2430
+
2431
+ float token_sum = 0.0f;
2432
+ for (int i = 0; i < n_vocab_logits; ++i) {
2433
+ token_sum += expf(pstate.logits[i]);
2434
+ }
2435
+ float token_p = expf(token_logit) / token_sum;
2436
+
2437
+ parakeet_token_data token_data;
2438
+ token_data.id = token_id;
2439
+ token_data.duration_idx = duration_idx;
2440
+ token_data.duration_value = duration_value;
2441
+ token_data.frame_index = frame_index;
2442
+ token_data.p = token_p;
2443
+ token_data.plog = token_logit;
2444
+ token_data.t0 = frame_index * pctx.model.hparams.subsampling_factor;
2445
+ token_data.t1 = (frame_index + duration_value) * pctx.model.hparams.subsampling_factor;
2446
+ token_data.is_word_start = is_word_start_token(pctx.vocab, token_id);
2447
+
2448
+ return token_data;
2449
+ }
2450
+
2451
+ static bool parakeet_decode(
2452
+ parakeet_context & pctx,
2453
+ parakeet_state & pstate,
2454
+ parakeet_batch & batch,
2455
+ const int n_threads,
2456
+ const parakeet_full_params * params = nullptr) {
2457
+ const auto & hparams = pctx.model.hparams;
2458
+ const auto & tdt_durations = pctx.model.tdt_durations;
2459
+
2460
+ const int n_tdt_durations = hparams.n_tdt_durations;
2461
+ const int n_frames = pstate.n_frames;
2462
+ const int blank_id = pctx.vocab.token_blank;
2463
+ const int n_vocab_logits = blank_id + 1;
2464
+ const int max_tokens_per_timestep = hparams.n_max_tokens;
2465
+
2466
+ // time index into the encoder frame (current time frame)
2467
+ int t = 0;
2468
+ // number of symbols emitted for the current time frame
2469
+ int tokens_emitted = 0;
2470
+
2471
+ // Start with the blank token (8192)
2472
+ parakeet_token last_token = blank_id;
2473
+
2474
+ PARAKEET_LOG_DEBUG("parakeet_decode: starting decode with n_frames=%d\n", n_frames);
2475
+
2476
+ batch.n_tokens = 1;
2477
+ batch.token[0] = last_token;
2478
+ batch.logits[0] = 1;
2479
+ batch.i_time[0] = 0;
2480
+
2481
+ // run the prediction network for the initial blank token. This will
2482
+ // initialize the LSTM state and produce an initial hidden state that can
2483
+ // be used in the joint network below.
2484
+ if (!parakeet_predict(pctx, pstate, batch, n_threads,
2485
+ params ? params->abort_callback : nullptr,
2486
+ params ? params->abort_callback_user_data : nullptr)) {
2487
+ return false;
2488
+ }
2489
+
2490
+ // process all time frames of the encoder output
2491
+ while (t < n_frames) {
2492
+ batch.n_tokens = 1;
2493
+ batch.i_time[0] = t;
2494
+ batch.logits[0] = 1;
2495
+
2496
+ // Use the current encoder frame (t) and the output of the prediction to
2497
+ // generate probabilities for the next token and duration. batch.i_time
2498
+ // is used to select the correct frame from the encoder output.
2499
+ // The joint network outputs logits for all the tokens in the vocabulary
2500
+ // plus the blank token, and also n_duration logits for the duration
2501
+ // tokens which contain information about how many frames to skip/advance forward.
2502
+ if (!parakeet_joint(pctx, pstate, batch, n_threads,
2503
+ params ? params->abort_callback : nullptr,
2504
+ params ? params->abort_callback_user_data : nullptr)) {
2505
+ return false;
2506
+ }
2507
+
2508
+ const int64_t t_start_sample_us = ggml_time_us();
2509
+
2510
+ // find the best token (greedy).
2511
+ // TODO: implement beam search?
2512
+ int best_token = 0;
2513
+ float max_logit = -1e10f;
2514
+ for (int i = 0; i < n_vocab_logits; ++i) {
2515
+ if (pstate.logits[i] > max_logit) {
2516
+ max_logit = pstate.logits[i];
2517
+ best_token = i;
2518
+ }
2519
+ }
2520
+
2521
+ // find the max index of the duration logits, and look up that index
2522
+ // value in the tdt_durations array to get the actual duration value.
2523
+ int best_duration_idx = 0;
2524
+ float best_duration_logit = -1e10f;
2525
+ for (int i = 0; i < n_tdt_durations; ++i) {
2526
+ if (pstate.logits[n_vocab_logits + i] > best_duration_logit) {
2527
+ best_duration_logit = pstate.logits[n_vocab_logits + i];
2528
+ best_duration_idx = i;
2529
+ }
2530
+ }
2531
+ // look up that max duration index value in the tdt_durations array to
2532
+ // get the actual duration value.
2533
+ int duration = tdt_durations[best_duration_idx];
2534
+
2535
+ if (best_token == blank_id) {
2536
+ if (duration == 0) {
2537
+ duration = 1;
2538
+ }
2539
+ // skip forward by duration time frames.
2540
+ t += duration;
2541
+ // reset symbols emitted counter
2542
+ tokens_emitted = 0;
2543
+ // continue without predicting.
2544
+ continue;
2545
+ }
2546
+
2547
+ // Emit non-blank token at current frame t.
2548
+ pstate.decoded_tokens.push_back(best_token);
2549
+ pstate.t_sample_us += ggml_time_us() - t_start_sample_us;
2550
+ pstate.n_sample++;
2551
+
2552
+ parakeet_token_data token_data = create_token_data(
2553
+ pctx, pstate, best_token, best_duration_idx, duration, t,
2554
+ max_logit, n_vocab_logits);
2555
+
2556
+ pstate.decoded_token_data.push_back(token_data);
2557
+
2558
+ // Call token callback if registered (for real-time streaming)
2559
+ if (params && params->new_token_callback) {
2560
+ params->new_token_callback(&pctx, &pstate, &token_data, params->new_token_callback_user_data);
2561
+ }
2562
+
2563
+ last_token = best_token;
2564
+
2565
+ // advance predictor for the non-blank token.
2566
+ batch.token[0] = last_token;
2567
+ if (!parakeet_predict(pctx, pstate, batch, n_threads,
2568
+ params ? params->abort_callback : nullptr,
2569
+ params ? params->abort_callback_user_data : nullptr)) {
2570
+ return false;
2571
+ }
2572
+
2573
+ // if duration greater than 0, continue looping over the encoder frames
2574
+ // and skip to the updated time frame (t).
2575
+ if (duration > 0) {
2576
+ t += duration;
2577
+ tokens_emitted = 0;
2578
+ continue;
2579
+ }
2580
+
2581
+ // if duration is zero we stay on the current time frame.
2582
+ tokens_emitted++;
2583
+ if (tokens_emitted >= max_tokens_per_timestep) {
2584
+ t += 1; // forced blank/time advance behavior
2585
+ tokens_emitted = 0;
2586
+ }
2587
+ }
2588
+
2589
+ return true;
2590
+ }
2591
+
2592
+ // 500 -> 00:05.000
2593
+ // 6000 -> 01:00.000
2594
+ // naive Discrete Fourier Transform
2595
+ // input is real-valued
2596
+ // output is complex-valued
2597
+ static void dft(const float* in, int N, float* out, const parakeet_mel_cache & cache) {
2598
+ const int sin_cos_step = cache.n_fft / N;
2599
+
2600
+ for (int k = 0; k < N; k++) {
2601
+ float re = 0;
2602
+ float im = 0;
2603
+
2604
+ for (int n = 0; n < N; n++) {
2605
+ int idx = (k * n * sin_cos_step) % cache.n_fft; // t = 2*M_PI*k*n/N
2606
+ re += in[n]*cache.cos_vals[idx]; // cos(t)
2607
+ im -= in[n]*cache.sin_vals[idx]; // sin(t)
2608
+ }
2609
+
2610
+ out[k*2 + 0] = re;
2611
+ out[k*2 + 1] = im;
2612
+ }
2613
+ }
2614
+
2615
+ // Cooley-Tukey FFT
2616
+ // poor man's implementation - use something better
2617
+ // input is real-valued
2618
+ // output is complex-valued
2619
+ static void fft(float* in, int N, float* out, const parakeet_mel_cache & cache) {
2620
+ if (N == 1) {
2621
+ out[0] = in[0];
2622
+ out[1] = 0;
2623
+ return;
2624
+ }
2625
+
2626
+ const int half_N = N / 2;
2627
+ if (N - half_N*2 == 1) {
2628
+ dft(in, N, out, cache);
2629
+ return;
2630
+ }
2631
+
2632
+ float* even = in + N;
2633
+ for (int i = 0; i < half_N; ++i) {
2634
+ even[i]= in[2*i];
2635
+ }
2636
+ float* even_fft = out + 2 * N;
2637
+ fft(even, half_N, even_fft, cache);
2638
+
2639
+ float* odd = even;
2640
+ for (int i = 0; i < half_N; ++i) {
2641
+ odd[i] = in[2*i + 1];
2642
+ }
2643
+ float* odd_fft = even_fft + N;
2644
+ fft(odd, half_N, odd_fft, cache);
2645
+
2646
+ const int sin_cos_step = cache.n_fft / N;
2647
+ for (int k = 0; k < half_N; k++) {
2648
+ int idx = k * sin_cos_step; // t = 2*M_PI*k/N
2649
+ float re = cache.cos_vals[idx]; // cos(t)
2650
+ float im = -cache.sin_vals[idx]; // sin(t)
2651
+
2652
+ float re_odd = odd_fft[2*k + 0];
2653
+ float im_odd = odd_fft[2*k + 1];
2654
+
2655
+ out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
2656
+ out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
2657
+
2658
+ out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
2659
+ out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
2660
+ }
2661
+ }
2662
+
2663
+ struct mel_worker_params {
2664
+ int ith;
2665
+ int window_size;
2666
+ int n_samples;
2667
+ int frame_size;
2668
+ int frame_step;
2669
+ int n_threads;
2670
+ };
2671
+
2672
+ static void log_mel_spectrogram_worker_thread(
2673
+ mel_worker_params params,
2674
+ const float * window_func,
2675
+ const std::vector<float> & samples,
2676
+ const parakeet_filters & filters,
2677
+ parakeet_mel & mel,
2678
+ const parakeet_mel_cache & cache) {
2679
+ std::vector<float> fft_in(params.frame_size * 2, 0.0);
2680
+ std::vector<float> fft_out(params.frame_size * 2 * 2 * 2);
2681
+
2682
+ int n_fb = filters.n_fb; // number of frequency bins
2683
+ int i = params.ith;
2684
+
2685
+ // make sure n_fb == 1 + (frame_size / 2), bin_0 to bin_nyquist
2686
+ assert(n_fb == 1 + (params.frame_size / 2));
2687
+
2688
+ const double eps = 5.960464477539063e-08;
2689
+
2690
+ // calculate FFT only when fft_in are not all zero
2691
+ for (; i < std::min(params.n_samples / params.frame_step + 1, mel.n_len); i += params.n_threads) {
2692
+ const int offset = i * params.frame_step;
2693
+
2694
+ const int window_pad_left = (params.frame_size - params.window_size) / 2;
2695
+
2696
+ // Zero-pad left
2697
+ std::fill(fft_in.begin(), fft_in.begin() + window_pad_left, 0.0f);
2698
+
2699
+ // Apply windowed samples in the center
2700
+ const int n_to_process = std::min({params.window_size, params.n_samples - offset});
2701
+ for (int j = 0; j < n_to_process; j++) {
2702
+ fft_in[window_pad_left + j] = window_func[j] * samples[offset + window_pad_left + j];
2703
+ }
2704
+
2705
+ // Zero-pad right (and any samples we didn't have)
2706
+ std::fill(fft_in.begin() + window_pad_left + n_to_process, fft_in.begin() + params.frame_size, 0.0f);
2707
+
2708
+ // FFT
2709
+ fft(fft_in.data(), params.frame_size, fft_out.data(), cache);
2710
+
2711
+ // Calculate modulus^2 of complex numbers
2712
+ // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
2713
+ for (int j = 0; j < n_fb; j++) {
2714
+ fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
2715
+ }
2716
+
2717
+ // mel spectrogram
2718
+ for (int j = 0; j < mel.n_mel; j++) {
2719
+ double sum = 0.0;
2720
+ // unroll loop (suggested by GH user @lunixbochs)
2721
+ int k = 0;
2722
+ for (k = 0; k < n_fb - 3; k += 4) {
2723
+ sum +=
2724
+ fft_out[k + 0] * filters.data[j * n_fb + k + 0] +
2725
+ fft_out[k + 1] * filters.data[j * n_fb + k + 1] +
2726
+ fft_out[k + 2] * filters.data[j * n_fb + k + 2] +
2727
+ fft_out[k + 3] * filters.data[j * n_fb + k + 3];
2728
+ }
2729
+ // handle n_fb remainder
2730
+ for (; k < n_fb; k++) {
2731
+ sum += fft_out[k] * filters.data[j * n_fb + k];
2732
+ }
2733
+
2734
+ mel.data[i * mel.n_mel + j] = std::log(sum + eps);
2735
+ }
2736
+ }
2737
+
2738
+ // Otherwise fft_out are all zero - use log(eps) for consistency
2739
+ const double empty_sum = std::log(eps);
2740
+ for (; i < mel.n_len; i += params.n_threads) {
2741
+ for (int j = 0; j < mel.n_mel; j++) {
2742
+ mel.data[i * mel.n_mel + j] = empty_sum;
2743
+ }
2744
+ }
2745
+ }
2746
+
2747
+ static bool log_mel_spectrogram(
2748
+ parakeet_state & wstate,
2749
+ const float * samples,
2750
+ const int n_samples,
2751
+ const int /*sample_rate*/,
2752
+ const int frame_size,
2753
+ const int frame_step,
2754
+ const int n_mel,
2755
+ const int n_threads,
2756
+ const parakeet_filters & filters,
2757
+ const bool debug,
2758
+ parakeet_mel & mel,
2759
+ const parakeet_mel_cache & cache) {
2760
+ const int64_t t_start_us = ggml_time_us();
2761
+
2762
+ const float * window_func = cache.window.empty() ? cache.hann_window.data() : cache.window.data();
2763
+ const int window_size = cache.window.empty() ? cache.n_fft : cache.window.size();
2764
+
2765
+ std::vector<float> samples_preprocessed(samples, samples + n_samples);
2766
+
2767
+ // Apply preemphasis filter (high-pass): x[i] = x[i] - 0.97 * x[i-1]
2768
+ {
2769
+ const float preemph = 0.97f;
2770
+ for (int i = n_samples - 1; i > 0; i--) {
2771
+ samples_preprocessed[i] = samples_preprocessed[i] - preemph * samples_preprocessed[i - 1];
2772
+ }
2773
+ }
2774
+
2775
+ // Parakeet Pytorch implementation uses centered contant padding.
2776
+ const size_t pad = (size_t)(frame_size / 2);
2777
+ std::vector<float> samples_padded(n_samples + 2 * pad, 0.0f);
2778
+ std::copy(samples_preprocessed.begin(), samples_preprocessed.end(), samples_padded.begin() + pad);
2779
+
2780
+ mel.n_mel = n_mel;
2781
+ mel.n_len = (samples_padded.size() - frame_size) / frame_step + 1;
2782
+ mel.n_len_org = mel.n_len;
2783
+ mel.data.resize(mel.n_mel * mel.n_len);
2784
+
2785
+ // Worker Threads (STFT + Mel + Natural Log)
2786
+ {
2787
+ std::vector<std::thread> workers(n_threads - 1);
2788
+ const mel_worker_params mel_params { 0, window_size, (int)samples_padded.size(), frame_size, frame_step, n_threads };
2789
+
2790
+ for (int iw = 0; iw < n_threads - 1; ++iw) {
2791
+ mel_worker_params params = mel_params;
2792
+ params.ith = iw + 1;
2793
+ workers[iw] = std::thread(log_mel_spectrogram_worker_thread,
2794
+ params,
2795
+ window_func,
2796
+ std::cref(samples_padded),
2797
+ std::cref(filters),
2798
+ std::ref(mel),
2799
+ std::cref(cache));
2800
+ }
2801
+
2802
+ log_mel_spectrogram_worker_thread(
2803
+ mel_params,
2804
+ window_func,
2805
+ samples_padded,
2806
+ filters,
2807
+ mel,
2808
+ cache);
2809
+
2810
+ for (int iw = 0; iw < n_threads - 1; ++iw) {
2811
+ workers[iw].join();
2812
+ }
2813
+ }
2814
+
2815
+ {
2816
+ const double eps = 1e-5;
2817
+ int valid_frames = n_samples / frame_step;
2818
+
2819
+ for (int j = 0; j < mel.n_mel; j++) {
2820
+ double sum = 0.0;
2821
+ double sq_diff_sum = 0.0;
2822
+
2823
+ // Calculate Mean ONLY on valid audio frames
2824
+ for (int i = 0; i < valid_frames; i++) {
2825
+ sum += (double)mel.data[i * mel.n_mel + j];
2826
+ }
2827
+ double mean = sum / valid_frames;
2828
+
2829
+ // Calculate Variance ONLY on valid audio frames
2830
+ for (int i = 0; i < valid_frames; i++) {
2831
+ double diff = (double)mel.data[i * mel.n_mel + j] - mean;
2832
+ sq_diff_sum += diff * diff;
2833
+ }
2834
+
2835
+ double std_dev = std::sqrt(sq_diff_sum / (valid_frames - 1.0));
2836
+ double denominator = std_dev + eps;
2837
+
2838
+ // Apply to ALL frames (including the padded ones)
2839
+ for (int i = 0; i < mel.n_len; i++) {
2840
+ mel.data[i * mel.n_mel + j] = (float)((mel.data[i * mel.n_mel + j] - mean) / denominator);
2841
+ }
2842
+ }
2843
+ }
2844
+
2845
+ wstate.t_mel_us += ggml_time_us() - t_start_us;
2846
+
2847
+ if (debug) {
2848
+ std::ofstream outFile("log_mel_spectrogram.json");
2849
+ outFile << "[";
2850
+ for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
2851
+ outFile << mel.data[i] << ", ";
2852
+ }
2853
+ outFile << mel.data[mel.data.size() - 1] << "]";
2854
+ outFile.close();
2855
+ }
2856
+
2857
+ return true;
2858
+ }
2859
+
2860
+ static std::vector<parakeet_vocab::id> tokenize(const parakeet_vocab & vocab, const std::string & text) {
2861
+ std::vector<parakeet_vocab::id> tokens;
2862
+ const std::string normalized = sentencepiece_normalize(text);
2863
+
2864
+ size_t i = 0;
2865
+ while (i < normalized.size()) {
2866
+ const size_t remaining = normalized.size() - i;
2867
+ const size_t max_len = std::min(vocab.max_token_length, remaining);
2868
+
2869
+ bool found = false;
2870
+ for (size_t len = max_len; len > 0; --len) {
2871
+ const auto it = vocab.token_to_id.find(normalized.substr(i, len));
2872
+ if (it != vocab.token_to_id.end() && !is_sentencepiece_control(it->first)) {
2873
+ tokens.push_back(it->second);
2874
+ i += len;
2875
+ found = true;
2876
+ break;
2877
+ }
2878
+ }
2879
+
2880
+ if (!found) {
2881
+ if (vocab.token_unk >= 0) {
2882
+ tokens.push_back(vocab.token_unk);
2883
+ }
2884
+
2885
+ const unsigned char c = static_cast<unsigned char>(normalized[i]);
2886
+ i += utf8_codepoint_len(c);
2887
+ }
2888
+ }
2889
+
2890
+ return tokens;
2891
+ }
2892
+
2893
+
2894
+ //
2895
+ // interface implementation
2896
+ //
2897
+
2898
+ struct parakeet_state * parakeet_init_state(parakeet_context * ctx) {
2899
+ parakeet_state * state = new parakeet_state;
2900
+
2901
+ state->backends = parakeet_backend_init(ctx->params);
2902
+ if (state->backends.empty()) {
2903
+ PARAKEET_LOG_ERROR("%s: parakeet_backend_init() failed\n", __func__);
2904
+ parakeet_free_state(state);
2905
+ return nullptr;
2906
+ }
2907
+
2908
+ const int batch_size = ctx->model.hparams.n_audio_ctx;
2909
+
2910
+ state->logits.reserve(ctx->vocab.n_vocab * batch_size);
2911
+
2912
+ state->batch = parakeet_batch_init(batch_size);
2913
+
2914
+ {
2915
+ const int n_audio_state = ctx->model.hparams.n_audio_state;
2916
+ const int subsampl_factor = ctx->model.hparams.subsampling_factor;
2917
+ const int n_frames_max = (batch_size + subsampl_factor - 1) / subsampl_factor;
2918
+
2919
+ if (!parakeet_enc_state_init(*state, state->backends[0], n_audio_state, n_frames_max)) {
2920
+ PARAKEET_LOG_ERROR("%s: parakeet_enc_state_init() failed\n", __func__);
2921
+ parakeet_free_state(state);
2922
+ return nullptr;
2923
+ }
2924
+
2925
+ const size_t mem_enc_ctx = state->enc_out_buf.size();
2926
+ const size_t mem_enc_out_buf = ggml_backend_buffer_get_size(state->enc_out_buffer);
2927
+ PARAKEET_LOG_INFO("%s: enc_out state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__,
2928
+ mem_enc_ctx / 1024.0 / 1024.0, mem_enc_out_buf / 1024.0 / 1024.0);
2929
+ }
2930
+
2931
+ // conv/encoder allocator
2932
+ bool ok = parakeet_sched_graph_init(state->sched_encode, state->backends,
2933
+ [&]() {
2934
+ return parakeet_build_graph_encode(*ctx, *state);
2935
+ });
2936
+
2937
+ if (!ok) {
2938
+ PARAKEET_LOG_ERROR("%s: failed to init encode allocator\n", __func__);
2939
+ parakeet_free_state(state);
2940
+ return nullptr;
2941
+ }
2942
+ state->sched_encode_n_audio_ctx = state->n_audio_ctx > 0 ? state->n_audio_ctx : ctx->model.hparams.n_audio_ctx;
2943
+
2944
+ if (!parakeet_lstm_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_layers, ctx->model.hparams.n_pred_dim)) {
2945
+ PARAKEET_LOG_ERROR("%s: parakeet_lstm_states_init () failed\n", __func__);
2946
+ parakeet_free_state(state);
2947
+ return nullptr;
2948
+ }
2949
+
2950
+ {
2951
+ const size_t mem_lstm_ctx = state->lstm_state.ctx_buf.size();
2952
+ const size_t mem_lstm_buf = ggml_backend_buffer_get_size(state->lstm_state.buffer);
2953
+ PARAKEET_LOG_INFO("%s: lstm state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__,
2954
+ mem_lstm_ctx / 1024.0 / 1024.0, mem_lstm_buf / 1024.0 / 1024.0);
2955
+ }
2956
+
2957
+ if (!parakeet_pred_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_dim)) {
2958
+ PARAKEET_LOG_ERROR("%s: parakeet_pred_state_init() failed\n", __func__);
2959
+ parakeet_free_state(state);
2960
+ return nullptr;
2961
+ }
2962
+
2963
+ {
2964
+ const size_t mem_pred_ctx = state->pred_out_buf.size();
2965
+ const size_t mem_pred_out_buf = ggml_backend_buffer_get_size(state->pred_out_buffer);
2966
+ PARAKEET_LOG_INFO("%s: pred state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__,
2967
+ mem_pred_ctx / 1024.0 / 1024.0, mem_pred_out_buf / 1024.0 / 1024.0);
2968
+ }
2969
+
2970
+ PARAKEET_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_encode) / 1e6);
2971
+
2972
+ {
2973
+ bool ok = parakeet_sched_graph_init(state->sched_decode, state->backends,
2974
+ [&]() {
2975
+ const auto & hparams = ctx->model.hparams;
2976
+ const int n_tokens = hparams.n_audio_ctx; // Use audio ctx for Parakeet
2977
+
2978
+ parakeet_batch_prep_legacy(state->batch, nullptr, n_tokens, 0, 0);
2979
+
2980
+ return parakeet_build_graph_prediction(*ctx, *state, state->batch, true);
2981
+ });
2982
+
2983
+ if (!ok) {
2984
+ PARAKEET_LOG_ERROR("%s: failed to init decoder allocator\n", __func__);
2985
+ parakeet_free_state(state);
2986
+ return nullptr;
2987
+ }
2988
+
2989
+ PARAKEET_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_decode) / 1e6);
2990
+ }
2991
+
2992
+ return state;
2993
+ }
2994
+
2995
+ struct parakeet_context_params parakeet_context_default_params() {
2996
+ struct parakeet_context_params result = {
2997
+ /*.use_gpu =*/ true,
2998
+ /*.gpu_device =*/ 0,
2999
+ };
3000
+ return result;
3001
+ }
3002
+
3003
+ struct parakeet_context * parakeet_init_from_file_with_params_no_state(const char * path_model, struct parakeet_context_params params) {
3004
+ PARAKEET_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
3005
+ #ifdef _MSC_VER
3006
+ // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
3007
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
3008
+ std::wstring path_model_wide = converter.from_bytes(path_model);
3009
+ auto fin = std::ifstream(path_model_wide, std::ios::binary);
3010
+ #else
3011
+ auto fin = std::ifstream(path_model, std::ios::binary);
3012
+ #endif
3013
+ if (!fin) {
3014
+ PARAKEET_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
3015
+ return nullptr;
3016
+ }
3017
+
3018
+ parakeet_model_loader loader = {};
3019
+
3020
+ loader.context = &fin;
3021
+
3022
+ loader.read = [](void * ctx, void * output, size_t read_size) {
3023
+ std::ifstream * fin = (std::ifstream*)ctx;
3024
+ fin->read((char *)output, read_size);
3025
+ return read_size;
3026
+ };
3027
+
3028
+ loader.eof = [](void * ctx) {
3029
+ std::ifstream * fin = (std::ifstream*)ctx;
3030
+ return fin->eof();
3031
+ };
3032
+
3033
+ loader.close = [](void * ctx) {
3034
+ std::ifstream * fin = (std::ifstream*)ctx;
3035
+ fin->close();
3036
+ };
3037
+
3038
+ auto ctx = parakeet_init_with_params_no_state(&loader, params);
3039
+
3040
+ if (ctx) {
3041
+ ctx->path_model = path_model;
3042
+ }
3043
+
3044
+ return ctx;
3045
+ }
3046
+
3047
+ struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params) {
3048
+ struct buf_context {
3049
+ uint8_t* buffer;
3050
+ size_t size;
3051
+ size_t current_offset;
3052
+ };
3053
+
3054
+ buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
3055
+
3056
+ PARAKEET_LOG_INFO("%s: loading model from buffer\n", __func__);
3057
+
3058
+ parakeet_model_loader loader = {};
3059
+
3060
+ loader.context = &ctx;
3061
+
3062
+ loader.read = [](void * ctx, void * output, size_t read_size) {
3063
+ buf_context * buf = reinterpret_cast<buf_context *>(ctx);
3064
+
3065
+ size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;
3066
+
3067
+ memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
3068
+ buf->current_offset += size_to_copy;
3069
+
3070
+ return size_to_copy;
3071
+ };
3072
+
3073
+ loader.eof = [](void * ctx) {
3074
+ buf_context * buf = reinterpret_cast<buf_context *>(ctx);
3075
+
3076
+ return buf->current_offset >= buf->size;
3077
+ };
3078
+
3079
+ loader.close = [](void * /*ctx*/) { };
3080
+
3081
+ return parakeet_init_with_params_no_state(&loader, params);
3082
+ }
3083
+
3084
+ struct parakeet_context * parakeet_init_with_params_no_state(struct parakeet_model_loader * loader, struct parakeet_context_params params) {
3085
+ ggml_time_init();
3086
+
3087
+ PARAKEET_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
3088
+ PARAKEET_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
3089
+ PARAKEET_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count());
3090
+ PARAKEET_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count());
3091
+
3092
+ parakeet_context * ctx = new parakeet_context;
3093
+ ctx->params = params;
3094
+
3095
+ bool model_loaded = false;
3096
+ try {
3097
+ model_loaded = parakeet_model_load(loader, *ctx);
3098
+ } catch (const std::exception & e) {
3099
+ PARAKEET_LOG_ERROR("%s: exception during model load: %s\n", __func__, e.what());
3100
+ } catch (...) {
3101
+ PARAKEET_LOG_ERROR("%s: unknown exception during model load\n", __func__);
3102
+ }
3103
+
3104
+ if (!model_loaded) {
3105
+ loader->close(loader->context);
3106
+ PARAKEET_LOG_ERROR("%s: failed to load model\n", __func__);
3107
+ delete ctx;
3108
+ return nullptr;
3109
+ }
3110
+
3111
+ loader->close(loader->context);
3112
+
3113
+ // Initialize mel cache with model's FFT size
3114
+ ctx->mel_cache.init(ctx->model.hparams.n_fft);
3115
+ PARAKEET_LOG_INFO("%s: initialized mel cache with n_fft = %d\n", __func__, ctx->model.hparams.n_fft);
3116
+
3117
+ return ctx;
3118
+ }
3119
+
3120
+ struct parakeet_context * parakeet_init_from_file_with_params(const char * path_model, struct parakeet_context_params params) {
3121
+ parakeet_context * ctx = parakeet_init_from_file_with_params_no_state(path_model, params);
3122
+ if (!ctx) {
3123
+ return nullptr;
3124
+ }
3125
+
3126
+ ctx->state = parakeet_init_state(ctx);
3127
+ if (!ctx->state) {
3128
+ parakeet_free(ctx);
3129
+ return nullptr;
3130
+ }
3131
+
3132
+ return ctx;
3133
+ }
3134
+
3135
+ struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params) {
3136
+ parakeet_context * ctx = parakeet_init_from_buffer_with_params_no_state(buffer, buffer_size, params);
3137
+ if (!ctx) {
3138
+ return nullptr;
3139
+ }
3140
+
3141
+ ctx->state = parakeet_init_state(ctx);
3142
+ if (!ctx->state) {
3143
+ parakeet_free(ctx);
3144
+ return nullptr;
3145
+ }
3146
+
3147
+ return ctx;
3148
+ }
3149
+
3150
+ struct parakeet_context * parakeet_init_with_params(struct parakeet_model_loader * loader, struct parakeet_context_params params) {
3151
+ parakeet_context * ctx = parakeet_init_with_params_no_state(loader, params);
3152
+ if (!ctx) {
3153
+ return nullptr;
3154
+ }
3155
+
3156
+ ctx->state = parakeet_init_state(ctx);
3157
+ if (!ctx->state) {
3158
+ parakeet_free(ctx);
3159
+ return nullptr;
3160
+ }
3161
+
3162
+ return ctx;
3163
+ }
3164
+
3165
+ void parakeet_free_state(struct parakeet_state * state) {
3166
+ if (state) {
3167
+ ggml_backend_buffer_free(state->lstm_state.buffer);
3168
+ ggml_backend_buffer_free(state->pred_out_buffer);
3169
+ ggml_backend_buffer_free(state->enc_out_buffer);
3170
+
3171
+ parakeet_batch_free(state->batch);
3172
+
3173
+ parakeet_sched_free(state->sched_encode);
3174
+ parakeet_sched_free(state->sched_decode);
3175
+
3176
+ for (auto & backend : state->backends) {
3177
+ ggml_backend_free(backend);
3178
+ }
3179
+
3180
+ delete state;
3181
+ }
3182
+ }
3183
+
3184
+ void parakeet_free(struct parakeet_context * ctx) {
3185
+ if (ctx) {
3186
+ for (ggml_context * context : ctx->model.ctxs) {
3187
+ ggml_free(context);
3188
+ }
3189
+
3190
+ for (ggml_backend_buffer_t buf : ctx->model.buffers) {
3191
+ ggml_backend_buffer_free(buf);
3192
+ }
3193
+
3194
+ parakeet_free_state(ctx->state);
3195
+
3196
+ delete ctx;
3197
+ }
3198
+ }
3199
+
3200
// Deletes a heap-allocated parakeet_context_params. Passing null is allowed:
// `delete` on a null pointer does nothing.
void parakeet_free_context_params(struct parakeet_context_params * params) {
    delete params;
}
3205
+
3206
// Deletes a heap-allocated parakeet_full_params. Passing null is allowed:
// `delete` on a null pointer does nothing.
void parakeet_free_params(struct parakeet_full_params * params) {
    delete params;
}
3211
+
3212
+ int parakeet_pcm_to_mel_with_state(struct parakeet_context * ctx, struct parakeet_state * state, const float * samples, int n_samples, int n_threads) {
3213
+ if (!log_mel_spectrogram(*state,
3214
+ samples,
3215
+ n_samples,
3216
+ PARAKEET_SAMPLE_RATE,
3217
+ ctx->model.hparams.n_fft,
3218
+ PARAKEET_HOP_LENGTH,
3219
+ ctx->model.filters.n_mel,
3220
+ n_threads,
3221
+ ctx->model.filters,
3222
+ false, // debug
3223
+ state->mel,
3224
+ ctx->mel_cache)) {
3225
+ PARAKEET_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
3226
+ return -1;
3227
+ }
3228
+
3229
+ return 0;
3230
+ }
3231
+
3232
+ int parakeet_pcm_to_mel(struct parakeet_context * ctx, const float * samples, int n_samples, int n_threads) {
3233
+ return parakeet_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
3234
+ }
3235
+
3236
+ int parakeet_set_mel_with_state(
3237
+ struct parakeet_context * ctx,
3238
+ struct parakeet_state * state,
3239
+ const float * data,
3240
+ int n_len,
3241
+ int n_mel) {
3242
+ if (n_mel != ctx->model.filters.n_mel) {
3243
+ PARAKEET_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
3244
+ return -1;
3245
+ }
3246
+
3247
+ state->mel.n_len = n_len;
3248
+ state->mel.n_len_org = n_len;
3249
+ state->mel.n_mel = n_mel;
3250
+
3251
+ state->mel.data.resize(n_len*n_mel);
3252
+ memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
3253
+
3254
+ return 0;
3255
+ }
3256
+
3257
+ int parakeet_set_mel(
3258
+ struct parakeet_context * ctx,
3259
+ const float * data,
3260
+ int n_len,
3261
+ int n_mel) {
3262
+ return parakeet_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel);
3263
+ }
3264
+
3265
+ int parakeet_encode_with_state(struct parakeet_context * ctx, struct parakeet_state * state, int offset, int n_threads) {
3266
+ if (!parakeet_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
3267
+ PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__);
3268
+ return -1;
3269
+ }
3270
+
3271
+ return 0;
3272
+ }
3273
+
3274
+ int parakeet_encode(struct parakeet_context * ctx, int offset, int n_threads) {
3275
+ if (!parakeet_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
3276
+ PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__);
3277
+ return -1;
3278
+ }
3279
+
3280
+ return 0;
3281
+ }
3282
+
3283
+ int parakeet_tokenize(struct parakeet_context * ctx, const char * text, parakeet_token * tokens, int n_max_tokens) {
3284
+ const auto res = tokenize(ctx->vocab, text);
3285
+
3286
+ if (n_max_tokens < (int) res.size()) {
3287
+ PARAKEET_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
3288
+ return -(int) res.size();
3289
+ }
3290
+
3291
+ for (int i = 0; i < (int) res.size(); i++) {
3292
+ tokens[i] = res[i];
3293
+ }
3294
+
3295
+ return res.size();
3296
+ }
3297
+
3298
+ int parakeet_token_count(struct parakeet_context * ctx, const char * text) {
3299
+ return -parakeet_tokenize(ctx, text, NULL, 0);
3300
+ }
3301
+
3302
+ int parakeet_model_n_vocab(struct parakeet_context * ctx) {
3303
+ return ctx->model.hparams.n_vocab;
3304
+ }
3305
+
3306
+ int parakeet_model_n_audio_ctx(struct parakeet_context * ctx) {
3307
+ return ctx->model.hparams.n_audio_ctx;
3308
+ }
3309
+
3310
+ int parakeet_model_n_audio_state(struct parakeet_context * ctx) {
3311
+ return ctx->model.hparams.n_audio_state;
3312
+ }
3313
+
3314
+ int parakeet_model_n_audio_head(struct parakeet_context * ctx) {
3315
+ return ctx->model.hparams.n_audio_head;
3316
+ }
3317
+
3318
+ int parakeet_model_n_audio_layer(struct parakeet_context * ctx) {
3319
+ return ctx->model.hparams.n_audio_layer;
3320
+ }
3321
+
3322
+ int parakeet_model_n_mels(struct parakeet_context * ctx) {
3323
+ return ctx->model.hparams.n_mels;
3324
+ }
3325
+
3326
+ int parakeet_model_ftype(struct parakeet_context * ctx) {
3327
+ return ctx->model.hparams.ftype;
3328
+ }
3329
+
3330
+ int parakeet_n_len_from_state(struct parakeet_state * state) {
3331
+ return state->mel.n_len_org;
3332
+ }
3333
+
3334
+ int parakeet_n_len(struct parakeet_context * ctx) {
3335
+ return ctx->state->mel.n_len_org;
3336
+ }
3337
+
3338
+ int parakeet_n_vocab(struct parakeet_context * ctx) {
3339
+ return ctx->vocab.n_vocab;
3340
+ }
3341
+
3342
+ int parakeet_n_audio_ctx(struct parakeet_context * ctx) {
3343
+ return ctx->model.hparams.n_audio_ctx;
3344
+ }
3345
+
3346
+ float * parakeet_get_logits(struct parakeet_context * ctx) {
3347
+ return ctx->state->logits.data();
3348
+ }
3349
+
3350
+ float * parakeet_get_logits_from_state(struct parakeet_state * state) {
3351
+ return state->logits.data();
3352
+ }
3353
+
3354
+ const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token) {
3355
+ return ctx->vocab.id_to_token.at(token).c_str();
3356
+ }
3357
+
3358
+ int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len) {
3359
+ std::string text = sentencepiece_piece_to_text(token_str, is_first);
3360
+
3361
+ if (output == nullptr) {
3362
+ return text.size();
3363
+ }
3364
+
3365
+ int bytes_to_copy = std::min((int)text.size(), max_len - 1);
3366
+ if (bytes_to_copy > 0) {
3367
+ memcpy(output, text.c_str(), bytes_to_copy);
3368
+ output[bytes_to_copy] = '\0';
3369
+ } else if (max_len > 0) {
3370
+ output[0] = '\0';
3371
+ }
3372
+
3373
+ return text.size();
3374
+ }
3375
+
3376
+ parakeet_token parakeet_token_bos(struct parakeet_context * ctx) {
3377
+ return ctx->vocab.token_bos;
3378
+ }
3379
+
3380
+ parakeet_token parakeet_token_unk(struct parakeet_context * ctx) {
3381
+ return ctx->vocab.token_unk;
3382
+ }
3383
+
3384
+ parakeet_token parakeet_token_blank(struct parakeet_context * ctx) {
3385
+ return ctx->vocab.token_blank;
3386
+ }
3387
+
3388
+ struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx) {
3389
+ if (ctx->state == nullptr) {
3390
+ return nullptr;
3391
+ }
3392
+ parakeet_timings * timings = new parakeet_timings;
3393
+ timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample);
3394
+ timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode);
3395
+ timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode);
3396
+ return timings;
3397
+ }
3398
+
3399
+ void parakeet_print_timings(struct parakeet_context * ctx) {
3400
+ const int64_t t_end_us = ggml_time_us();
3401
+
3402
+ PARAKEET_LOG_INFO("\n");
3403
+ PARAKEET_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
3404
+ if (ctx->state != nullptr) {
3405
+
3406
+ const int32_t n_sample = std::max(1, ctx->state->n_sample);
3407
+ const int32_t n_encode = std::max(1, ctx->state->n_encode);
3408
+ const int32_t n_decode = std::max(1, ctx->state->n_decode);
3409
+ const int32_t n_predict = std::max(1, ctx->state->n_predict);
3410
+
3411
+ PARAKEET_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3412
+ PARAKEET_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3413
+ PARAKEET_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3414
+ PARAKEET_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3415
+ PARAKEET_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3416
+ PARAKEET_LOG_INFO("%s: predict time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_us, n_predict, 1e-3f * ctx->state->t_predict_us / n_predict);
3417
+ PARAKEET_LOG_INFO("%s: - build = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_build_us, n_predict, 1e-3f * ctx->state->t_predict_build_us / n_predict);
3418
+ PARAKEET_LOG_INFO("%s: - alloc = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_alloc_us, n_predict, 1e-3f * ctx->state->t_predict_alloc_us / n_predict);
3419
+ PARAKEET_LOG_INFO("%s: - compute = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_compute_us, n_predict, 1e-3f * ctx->state->t_predict_compute_us / n_predict);
3420
+
3421
+ }
3422
+ PARAKEET_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
3423
+ }
3424
+
3425
+ void parakeet_reset_timings(struct parakeet_context * ctx) {
3426
+ ctx->t_start_us = ggml_time_us();
3427
+ if (ctx->state != nullptr) {
3428
+ ctx->state->t_mel_us = 0;
3429
+ ctx->state->t_sample_us = 0;
3430
+ ctx->state->t_encode_us = 0;
3431
+ ctx->state->t_decode_us = 0;
3432
+ ctx->state->t_predict_us = 0;
3433
+ ctx->state->t_predict_build_us = 0;
3434
+ ctx->state->t_predict_alloc_us = 0;
3435
+ ctx->state->t_predict_compute_us = 0;
3436
+
3437
+ ctx->state->n_sample = 0;
3438
+ ctx->state->n_encode = 0;
3439
+ ctx->state->n_decode = 0;
3440
+ ctx->state->n_predict = 0;
3441
+ }
3442
+ }
3443
+
3444
+ const char * parakeet_print_system_info(void) {
3445
+ static std::string s;
3446
+
3447
+ s = "";
3448
+ s += "PARAKEET : ";
3449
+
3450
+ for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
3451
+ auto * reg = ggml_backend_reg_get(i);
3452
+ auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
3453
+ if (get_features_fn) {
3454
+ ggml_backend_feature * features = get_features_fn(reg);
3455
+ s += ggml_backend_reg_name(reg);
3456
+ s += " : ";
3457
+ for (; features->name; features++) {
3458
+ s += features->name;
3459
+ s += " = ";
3460
+ s += features->value;
3461
+ s += " | ";
3462
+ }
3463
+ }
3464
+ }
3465
+ return s.c_str();
3466
+ }
3467
+
3468
+ struct parakeet_context_params * parakeet_context_default_params_by_ref(void) {
3469
+ struct parakeet_context_params params = parakeet_context_default_params();
3470
+
3471
+ struct parakeet_context_params* result = new parakeet_context_params();
3472
+ *result = params;
3473
+ return result;
3474
+ }
3475
+
3476
+ struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy) {
3477
+ struct parakeet_full_params params = parakeet_full_default_params(strategy);
3478
+
3479
+ struct parakeet_full_params* result = new parakeet_full_params();
3480
+ *result = params;
3481
+ return result;
3482
+ }
3483
+
3484
+ struct parakeet_full_params parakeet_full_default_params(enum parakeet_sampling_strategy strategy) {
3485
+ struct parakeet_full_params result = {
3486
+ /*.strategy =*/ strategy,
3487
+ /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
3488
+ /*.offset_ms =*/ 0,
3489
+ /*.duration_ms =*/ 0,
3490
+ /*.no_context =*/ true,
3491
+ /*.audio_ctx =*/ 0,
3492
+ /*.new_token_callback =*/ nullptr,
3493
+ /*.new_token_callback_user_data =*/ nullptr,
3494
+ /*.new_segment_callback =*/ nullptr,
3495
+ /*.new_segment_callback_user_data =*/ nullptr,
3496
+ /*.progress_callback =*/ nullptr,
3497
+ /*.progress_callback_user_data =*/ nullptr,
3498
+ /*.encoder_begin_callback =*/ nullptr,
3499
+ /*.encoder_begin_callback_user_data =*/ nullptr,
3500
+ /*.abort_callback =*/ nullptr,
3501
+ /*.abort_callback_user_data =*/ nullptr,
3502
+ };
3503
+
3504
+ return result;
3505
+ }
3506
+
3507
+ static void parakeet_reset_state(struct parakeet_state * state) {
3508
+ state->decoded_tokens.clear();
3509
+ state->decoded_token_data.clear();
3510
+
3511
+ if (state->lstm_state.buffer) {
3512
+ ggml_backend_buffer_clear(state->lstm_state.buffer, 0);
3513
+ }
3514
+
3515
+ }
3516
+
3517
+ // Encode and decode the mel spectrogram already in state, without recomputing it.
3518
+ static int parakeet_chunk_with_state(
3519
+ struct parakeet_context * ctx,
3520
+ struct parakeet_state * state,
3521
+ struct parakeet_full_params params) {
3522
+ return parakeet_chunk(ctx, state, params, nullptr, 0);
3523
+ }
3524
+
3525
+ int parakeet_full_with_state(
3526
+ struct parakeet_context * ctx,
3527
+ struct parakeet_state * state,
3528
+ struct parakeet_full_params params,
3529
+ const float * samples,
3530
+ int n_samples) {
3531
+ state->result_all.clear();
3532
+
3533
+ if (params.no_context) {
3534
+ parakeet_reset_state(state);
3535
+ }
3536
+
3537
+ if (n_samples > 0) {
3538
+ if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
3539
+ PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
3540
+ return -2;
3541
+ }
3542
+ }
3543
+
3544
+ const int n_mel_total = state->mel.n_len;
3545
+ const int n_audio_ctx = ctx->model.hparams.n_audio_ctx;
3546
+
3547
+ if (n_mel_total <= n_audio_ctx) {
3548
+ if (params.progress_callback) {
3549
+ params.progress_callback(ctx, state, 0, params.progress_callback_user_data);
3550
+ }
3551
+ return parakeet_chunk_with_state(ctx, state, params);
3552
+ }
3553
+
3554
+ PARAKEET_LOG_DEBUG("%s: audio too long (%d mel > n_audio_ctx=%d), using dynamic encoder graph\n",
3555
+ __func__, n_mel_total, n_audio_ctx);
3556
+
3557
+ if (params.encoder_begin_callback) {
3558
+ if (!params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data)) {
3559
+ PARAKEET_LOG_ERROR("%s: encoder_begin_callback returned false\n", __func__);
3560
+ return -6;
3561
+ }
3562
+ }
3563
+
3564
+ if (params.progress_callback) {
3565
+ params.progress_callback(ctx, state, 0, params.progress_callback_user_data);
3566
+ }
3567
+
3568
+ if (!parakeet_ensure_encode_sched(*ctx, *state, n_mel_total)) {
3569
+ PARAKEET_LOG_ERROR("%s: failed to allocate dynamic encoder graph for %d mel frames\n",
3570
+ __func__, n_mel_total);
3571
+ return -6;
3572
+ }
3573
+
3574
+ state->n_audio_ctx = n_mel_total;
3575
+
3576
+ if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads,
3577
+ params.abort_callback, params.abort_callback_user_data)) {
3578
+ PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__);
3579
+ return -6;
3580
+ }
3581
+
3582
+ if (params.progress_callback) {
3583
+ params.progress_callback(ctx, state, 100, params.progress_callback_user_data);
3584
+ }
3585
+
3586
+ const size_t tokens_before = state->decoded_tokens.size();
3587
+
3588
+ if (!parakeet_decode(*ctx, *state, state->batch, params.n_threads, &params)) {
3589
+ PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__);
3590
+ return -7;
3591
+ }
3592
+
3593
+ const size_t tokens_after = state->decoded_tokens.size();
3594
+ const size_t new_token_count = tokens_after - tokens_before;
3595
+
3596
+ if (new_token_count > 0) {
3597
+ std::string text;
3598
+ std::vector<parakeet_token_data> result_tokens;
3599
+
3600
+ for (size_t i = tokens_before; i < tokens_after; i++) {
3601
+ const auto token_id = state->decoded_tokens[i];
3602
+ const char * tok_str = parakeet_token_to_str(ctx, token_id);
3603
+ if (tok_str) {
3604
+ const bool is_first = (tokens_before == 0) && text.empty();
3605
+ text += sentencepiece_piece_to_text(tok_str, is_first);
3606
+ }
3607
+ result_tokens.push_back(state->decoded_token_data[i]);
3608
+ }
3609
+
3610
+ refine_timestamps_tdt(ctx->vocab, result_tokens);
3611
+
3612
+ if (!text.empty()) {
3613
+ parakeet_segment seg;
3614
+ seg.t0 = 0;
3615
+ seg.t1 = state->n_frames;
3616
+ seg.text = text;
3617
+ seg.tokens = result_tokens;
3618
+ state->result_all.push_back(std::move(seg));
3619
+
3620
+ if (params.new_segment_callback) {
3621
+ params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data);
3622
+ }
3623
+ }
3624
+ }
3625
+
3626
+ return 0;
3627
+ }
3628
+
3629
+ int parakeet_full(
3630
+ struct parakeet_context * ctx,
3631
+ struct parakeet_full_params params,
3632
+ const float * samples,
3633
+ int n_samples) {
3634
+ return parakeet_full_with_state(ctx, ctx->state, params, samples, n_samples);
3635
+ }
3636
+
3637
+ int parakeet_chunk(
3638
+ struct parakeet_context * ctx,
3639
+ struct parakeet_state * state,
3640
+ struct parakeet_full_params params,
3641
+ const float * samples,
3642
+ int n_samples) {
3643
+
3644
+ if (params.no_context) {
3645
+ parakeet_reset_state(state);
3646
+ }
3647
+
3648
+ if (n_samples > 0) {
3649
+ if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
3650
+ PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
3651
+ return -2;
3652
+ }
3653
+ }
3654
+
3655
+ if (params.audio_ctx == 0) {
3656
+ const int total_len = parakeet_n_len_from_state(state);
3657
+ const int model_max_ctx = parakeet_n_audio_ctx(ctx);
3658
+ params.audio_ctx = std::min(total_len, model_max_ctx);
3659
+ PARAKEET_LOG_DEBUG("Processing audio: total_frames=%d, chunk_size=%d\n", total_len, params.audio_ctx);
3660
+ }
3661
+ state->n_audio_ctx = params.audio_ctx;
3662
+
3663
+ const int n_frames = parakeet_n_len_from_state(state);
3664
+
3665
+ if (!parakeet_ensure_encode_sched(*ctx, *state, state->n_audio_ctx)) {
3666
+ PARAKEET_LOG_ERROR("%s: failed to allocate encoder graph for %d mel frames\n",
3667
+ __func__, state->n_audio_ctx);
3668
+ return -6;
3669
+ }
3670
+
3671
+ if (params.encoder_begin_callback) {
3672
+ if (!params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data)) {
3673
+ PARAKEET_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
3674
+ return -6;
3675
+ }
3676
+ }
3677
+ if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
3678
+ PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__);
3679
+ return -6;
3680
+ }
3681
+
3682
+ const size_t tokens_before = state->decoded_tokens.size();
3683
+
3684
+ if (!parakeet_decode(*ctx, *state, state->batch, params.n_threads, &params)) {
3685
+ PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__);
3686
+ return -7;
3687
+ }
3688
+
3689
+ const size_t tokens_after = state->decoded_tokens.size();
3690
+ const size_t new_token_count = tokens_after - tokens_before;
3691
+
3692
+ if (new_token_count > 0) {
3693
+ std::string text;
3694
+ std::vector<parakeet_token_data> result_tokens;
3695
+
3696
+ for (size_t i = tokens_before; i < tokens_after; i++) {
3697
+ const auto token_id = state->decoded_tokens[i];
3698
+ const char * token_str = parakeet_token_to_str(ctx, token_id);
3699
+ if (token_str) {
3700
+ const bool is_first_piece = (tokens_before == 0) && text.empty();
3701
+ text += sentencepiece_piece_to_text(token_str, is_first_piece);
3702
+ }
3703
+
3704
+ // Use the stored token data from parakeet_decode
3705
+ result_tokens.push_back(state->decoded_token_data[i]);
3706
+ }
3707
+
3708
+ refine_timestamps_tdt(ctx->vocab, result_tokens);
3709
+
3710
+ if (!text.empty()) {
3711
+ parakeet_segment segment;
3712
+ segment.t0 = 0; // Caller tracks timing
3713
+ segment.t1 = n_frames;
3714
+ segment.text = text;
3715
+ segment.tokens = result_tokens;
3716
+
3717
+ state->result_all.push_back(std::move(segment));
3718
+
3719
+ if (params.new_segment_callback) {
3720
+ params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data);
3721
+ }
3722
+ }
3723
+ }
3724
+
3725
+ return 0;
3726
+ }
3727
+
3728
+ int parakeet_full_n_segments_from_state(struct parakeet_state * state) {
3729
+ return state->result_all.size();
3730
+ }
3731
+
3732
+ int parakeet_full_n_segments(struct parakeet_context * ctx) {
3733
+ return ctx->state->result_all.size();
3734
+ }
3735
+
3736
+ int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment) {
3737
+ return state->result_all[i_segment].t0;
3738
+ }
3739
+
3740
+ int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment) {
3741
+ return state->result_all[i_segment].t1;
3742
+ }
3743
+
3744
+ int64_t parakeet_full_get_segment_t0(struct parakeet_context * ctx, int i_segment) {
3745
+ return parakeet_full_get_segment_t0_from_state(ctx->state, i_segment);
3746
+ }
3747
+
3748
+ int64_t parakeet_full_get_segment_t1(struct parakeet_context * ctx, int i_segment) {
3749
+ return parakeet_full_get_segment_t1_from_state(ctx->state, i_segment);
3750
+ }
3751
+
3752
+ const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment) {
3753
+ return state->result_all[i_segment].text.c_str();
3754
+ }
3755
+
3756
+ const char * parakeet_full_get_segment_text(struct parakeet_context * ctx, int i_segment) {
3757
+ return ctx->state->result_all[i_segment].text.c_str();
3758
+ }
3759
+
3760
+ int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment) {
3761
+ return state->result_all[i_segment].tokens.size();
3762
+ }
3763
+
3764
+ int parakeet_full_n_tokens(struct parakeet_context * ctx, int i_segment) {
3765
+ return ctx->state->result_all[i_segment].tokens.size();
3766
+ }
3767
+
3768
+ const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token) {
3769
+ return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str();
3770
+ }
3771
+
3772
+ const char* parakeet_full_get_token_text(struct parakeet_context * ctx, int i_segment, int i_token) {
3773
+ return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str();
3774
+ }
3775
+
3776
+ parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token) {
3777
+ return state->result_all[i_segment].tokens[i_token].id;
3778
+ }
3779
+
3780
+ parakeet_token parakeet_full_get_token_id(struct parakeet_context * ctx, int i_segment, int i_token) {
3781
+ return ctx->state->result_all[i_segment].tokens[i_token].id;
3782
+ }
3783
+
3784
+ struct parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token) {
3785
+ return state->result_all[i_segment].tokens[i_token];
3786
+ }
3787
+
3788
+ struct parakeet_token_data parakeet_full_get_token_data(struct parakeet_context * ctx, int i_segment, int i_token) {
3789
+ return ctx->state->result_all[i_segment].tokens[i_token];
3790
+ }
3791
+
3792
+ float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token) {
3793
+ return state->result_all[i_segment].tokens[i_token].p;
3794
+ }
3795
+
3796
+ float parakeet_full_get_token_p(struct parakeet_context * ctx, int i_segment, int i_token) {
3797
+ return ctx->state->result_all[i_segment].tokens[i_token].p;
3798
+ }
3799
+
3800
+ void parakeet_log_set(ggml_log_callback log_callback, void * user_data) {
3801
+ g_state.log_callback = log_callback ? log_callback : parakeet_log_callback_default;
3802
+ g_state.log_callback_user_data = user_data;
3803
+ ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
3804
+ }
3805
+
3806
+ const char * parakeet_version(void) {
3807
+ return PARAKEET_VERSION;
3808
+ }
3809
+
3810
+ GGML_ATTRIBUTE_FORMAT(2, 3)
3811
+ static void parakeet_log_internal(ggml_log_level level, const char * format, ...) {
3812
+ va_list args;
3813
+ va_start(args, format);
3814
+ char buffer[1024];
3815
+ int len = vsnprintf(buffer, 1024, format, args);
3816
+ if (len < 1024) {
3817
+ g_state.log_callback(level, buffer, g_state.log_callback_user_data);
3818
+ } else {
3819
+ char* buffer2 = new char[len+1];
3820
+ vsnprintf(buffer2, len+1, format, args);
3821
+ buffer2[len] = 0;
3822
+ g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
3823
+ delete[] buffer2;
3824
+ }
3825
+ va_end(args);
3826
+ }
3827
+
3828
+ static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
3829
+ (void) level;
3830
+ (void) user_data;
3831
+ #ifndef PARAKEET_DEBUG
3832
+ if (level == GGML_LOG_LEVEL_DEBUG) {
3833
+ return;
3834
+ }
3835
+ #endif
3836
+ fputs(text, stderr);
3837
+ fflush(stderr);
3838
+ }