whispercpp 1.3.2 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (664) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +59 -27
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/build-xcframework.sh +24 -0
  19. data/ext/sources/examples/CMakeLists.txt +1 -0
  20. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  21. data/ext/sources/examples/addon.node/addon.cpp +154 -35
  22. data/ext/sources/examples/addon.node/index.js +10 -5
  23. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  24. data/ext/sources/examples/bench/bench.cpp +29 -18
  25. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  26. data/ext/sources/examples/cli/cli.cpp +7 -4
  27. data/ext/sources/examples/command/command.cpp +58 -32
  28. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/common-whisper.cpp +14 -7
  31. data/ext/sources/examples/lsp/lsp.cpp +21 -17
  32. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  33. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  34. data/ext/sources/examples/server/server.cpp +193 -35
  35. data/ext/sources/examples/server.py +6 -1
  36. data/ext/sources/examples/stream/stream.cpp +10 -2
  37. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  38. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  39. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
  40. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  41. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  42. data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
  43. data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
  44. data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
  45. data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
  46. data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
  47. data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
  48. data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
  49. data/ext/sources/examples/talk-llama/llama-context.h +68 -32
  50. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  52. data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
  53. data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
  54. data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
  55. data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
  56. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  57. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  58. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
  59. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
  60. data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
  61. data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
  62. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
  63. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
  64. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
  65. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
  66. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  67. data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
  68. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  69. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
  70. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  71. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  72. data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
  73. data/ext/sources/examples/talk-llama/llama-model.h +87 -9
  74. data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
  75. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  76. data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
  77. data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
  78. data/ext/sources/examples/talk-llama/llama.cpp +76 -17
  79. data/ext/sources/examples/talk-llama/llama.h +176 -151
  80. data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
  81. data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
  82. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  83. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  84. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
  85. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  86. data/ext/sources/ggml/CMakeLists.txt +106 -33
  87. data/ext/sources/ggml/cmake/common.cmake +24 -0
  88. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  89. data/ext/sources/ggml/include/ggml-backend.h +18 -2
  90. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  91. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  92. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  93. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  94. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  95. data/ext/sources/ggml/include/ggml.h +365 -21
  96. data/ext/sources/ggml/src/CMakeLists.txt +98 -25
  97. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  98. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  99. data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
  100. data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
  101. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
  102. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  103. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
  104. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  105. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  106. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  107. data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
  108. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
  109. data/ext/sources/ggml/src/ggml-common.h +21 -0
  110. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
  111. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
  112. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  113. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  114. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
  115. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
  116. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
  117. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  118. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  119. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
  120. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  121. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  122. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  123. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  124. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  125. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
  126. data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
  127. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
  128. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
  129. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
  130. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  131. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  132. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  133. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
  134. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
  135. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  136. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
  137. data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
  138. data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
  139. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
  140. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
  141. data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
  142. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
  143. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  144. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  145. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  146. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  147. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
  148. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
  149. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
  150. data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
  151. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  152. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  153. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  154. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  155. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  156. data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
  157. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  158. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  159. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  160. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  161. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  162. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  163. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  164. data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
  165. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
  166. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  167. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  168. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  169. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  170. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
  171. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
  172. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  173. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  174. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  175. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
  176. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  177. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  178. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  179. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
  180. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  181. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  182. data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
  183. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  184. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  185. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  186. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  187. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  188. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  189. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
  190. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
  191. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  192. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  193. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  195. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  196. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  197. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  198. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  199. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  200. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  201. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  202. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  203. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  204. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  205. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  206. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  208. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  210. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  211. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
  212. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  213. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
  214. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  234. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  235. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  236. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  237. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  238. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  239. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  240. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  241. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  242. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  243. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  244. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  245. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  246. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  247. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  248. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  249. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  251. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  252. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  254. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  255. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  259. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  260. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  262. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  270. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  271. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  272. data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
  274. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  277. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  278. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  279. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
  280. data/ext/sources/ggml/src/ggml-impl.h +229 -175
  281. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
  282. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  283. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  284. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  285. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  286. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  287. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  288. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  289. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
  290. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  291. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  292. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  293. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
  294. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  295. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  296. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
  297. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  344. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  345. data/ext/sources/ggml/src/ggml-quants.c +117 -24
  346. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  347. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
  348. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  349. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  350. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
  351. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  352. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  353. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
  354. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
  355. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
  356. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  357. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  358. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
  359. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
  360. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
  361. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
  362. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
  363. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  364. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  365. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  366. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
  367. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
  368. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  369. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  370. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  371. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  372. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
  373. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
  374. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  375. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  401. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  402. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  403. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  404. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
  449. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  450. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  451. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  452. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  453. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  454. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  455. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  456. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  457. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  458. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  459. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  460. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  461. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  462. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  463. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  464. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  465. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  466. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  467. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  468. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  469. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  470. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  471. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  472. data/ext/sources/ggml/src/ggml.c +802 -142
  473. data/ext/sources/ggml/src/ggml.cpp +26 -0
  474. data/ext/sources/ggml/src/gguf.cpp +32 -4
  475. data/ext/sources/include/whisper.h +2 -0
  476. data/ext/sources/src/CMakeLists.txt +2 -0
  477. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  478. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  479. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  480. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  481. data/ext/sources/src/whisper.cpp +241 -215
  482. data/ext/sources/tests/CMakeLists.txt +8 -1
  483. data/ext/sources/tests/test-vad-full.cpp +3 -3
  484. data/ext/sources/tests/test-vad.cpp +2 -2
  485. data/extsources.rb +15 -9
  486. data/lib/whisper/context.rb +15 -0
  487. data/lib/whisper/model/uri.rb +57 -2
  488. data/lib/whisper/segment.rb +58 -0
  489. data/sig/whisper.rbs +75 -38
  490. data/{tests → test}/helper.rb +1 -12
  491. data/{tests → test}/test_model.rb +9 -0
  492. data/test/test_package.rb +51 -0
  493. data/{tests → test}/test_params.rb +8 -0
  494. data/test/test_segment.rb +146 -0
  495. data/{tests → test}/test_whisper.rb +70 -0
  496. data/whispercpp.gemspec +2 -3
  497. metadata +246 -191
  498. data/ext/sources/.dockerignore +0 -3
  499. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  500. data/ext/sources/ci/run.sh +0 -336
  501. data/ext/sources/close-issue.yml +0 -28
  502. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  503. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  504. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  505. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  506. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  507. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  508. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  509. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  510. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  511. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  512. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  513. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  514. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  515. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  516. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  517. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  518. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
  519. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  520. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  521. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  522. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  523. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  524. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  525. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  526. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  527. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  548. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  549. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  550. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  551. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  552. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  553. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  554. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  555. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  556. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  557. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  558. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  559. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  560. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  561. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  562. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  563. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  564. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  565. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  566. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  567. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  568. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  569. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  570. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  571. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  572. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  573. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  574. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  575. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  576. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  577. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  578. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  579. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  580. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  581. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  582. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  583. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  584. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  585. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  586. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  587. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  588. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  589. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  590. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  591. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  592. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  593. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  594. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  595. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  596. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  597. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  598. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  599. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  600. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  601. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  602. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  603. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  604. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  605. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  606. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  607. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  608. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  609. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  610. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  611. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  612. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  613. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  614. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  615. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  616. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  617. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  618. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  619. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  620. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  621. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  622. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  623. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  624. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  625. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  626. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  627. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  628. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  629. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  630. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  631. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  632. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  633. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  634. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  635. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  636. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  637. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  638. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  639. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  640. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  641. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  642. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  643. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  644. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  645. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  646. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  647. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  648. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  649. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  650. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  651. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  652. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  653. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
  654. data/tests/test_package.rb +0 -46
  655. data/tests/test_segment.rb +0 -74
  656. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  657. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  658. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  659. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  660. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  661. /data/{tests → test}/test_callback.rb +0 -0
  662. /data/{tests → test}/test_error.rb +0 -0
  663. /data/{tests → test}/test_vad.rb +0 -0
  664. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -0,0 +1,276 @@
1
+ #include "set-rows.cuh"
2
+ #include "cpy-utils.cuh"
3
+
4
+ typedef void (*set_rows_kernel_t)(const char * src, char * dst);
5
+
6
+ // Generic quantized set_rows kernel template
7
+ template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
8
+ static __global__ void k_set_rows_quant(
9
+ const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
10
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
11
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
12
+ const int64_t s01, const int64_t s02, const int64_t s03,
13
+ const int64_t s10, const int64_t s11, const int64_t s12,
14
+ const int64_t s1, const int64_t s2, const int64_t s3) {
15
+
16
+ const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
17
+ const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
18
+
19
+ if (i >= ne_total) {
20
+ return;
21
+ }
22
+
23
+ const int64_t i_base = i * qk;
24
+ const int64_t i03 = i_base / (ne00 * ne01 * ne02);
25
+ const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
26
+ const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
27
+ const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
28
+
29
+ const int64_t i12 = i03 % ne12;
30
+ const int64_t i11 = i02 % ne11;
31
+ const int64_t i10 = i01;
32
+
33
+ const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
34
+
35
+ const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
36
+ block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type);
37
+
38
+ const float * src_block = src0_row + i00;
39
+ block_type * dst_block = dst_row_ptr + i00 / qk;
40
+
41
+ quantize_func(src_block, dst_block);
42
+
43
+ GGML_UNUSED(ne10);
44
+ GGML_UNUSED(ne13);
45
+ }
46
+
47
+ // Template dispatch function for quantized set_rows
48
+ template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
49
+ static void set_rows_cuda_quant(
50
+ const float * src0_d, const idx_t * src1_d, block_type * dst_d,
51
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
52
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
53
+ const size_t nb01, const size_t nb02, const size_t nb03,
54
+ const size_t nb10, const size_t nb11, const size_t nb12,
55
+ const size_t nb1, const size_t nb2, const size_t nb3,
56
+ cudaStream_t stream) {
57
+
58
+ GGML_ASSERT(ne00 % qk == 0);
59
+ const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
60
+ const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
61
+ const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
62
+ const dim3 grid_size(num_blocks);
63
+
64
+ const int64_t s01 = nb01/sizeof(float);
65
+ const int64_t s02 = nb02/sizeof(float);
66
+ const int64_t s03 = nb03/sizeof(float);
67
+ const int64_t s10 = nb10/sizeof(idx_t);
68
+ const int64_t s11 = nb11/sizeof(idx_t);
69
+ const int64_t s12 = nb12/sizeof(idx_t);
70
+ const int64_t s1 = nb1;
71
+ const int64_t s2 = nb2;
72
+ const int64_t s3 = nb3;
73
+
74
+ if (ne_total > 0) {
75
+ k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
76
+ src0_d, src1_d, dst_d,
77
+ ne00, ne01, ne02, ne03,
78
+ ne10, ne11, ne12, ne13,
79
+ s01, s02, s03,
80
+ s10, s11, s12,
81
+ s1, s2, s3);
82
+ }
83
+ }
84
+
85
+ template<typename src_t, typename idx_t, typename dst_t>
86
+ static __global__ void k_set_rows(
87
+ const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
88
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
89
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
90
+ const int64_t s01, const int64_t s02, const int64_t s03,
91
+ const int64_t s10, const int64_t s11, const int64_t s12,
92
+ const int64_t s1, const int64_t s2, const int64_t s3) {
93
+
94
+ const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
95
+ const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
96
+
97
+ if (i >= ne_total) {
98
+ return;
99
+ }
100
+
101
+ const int64_t i03 = i / (ne00 * ne01 * ne02);
102
+ const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
103
+ const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
104
+ const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
105
+
106
+ const int64_t i12 = i03 % ne12;
107
+ const int64_t i11 = i02 % ne11;
108
+ const int64_t i10 = i01;
109
+
110
+ const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
111
+
112
+ const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
113
+ dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;
114
+
115
+ dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
116
+
117
+ GGML_UNUSED(ne10);
118
+ GGML_UNUSED(ne13);
119
+ }
120
+
121
+ template<typename src_t, typename idx_t, typename dst_t>
122
+ static void set_rows_cuda(
123
+ const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d,
124
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
125
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
126
+ const size_t nb01, const size_t nb02, const size_t nb03,
127
+ const size_t nb10, const size_t nb11, const size_t nb12,
128
+ const size_t nb1, const size_t nb2, const size_t nb3,
129
+ cudaStream_t stream) {
130
+
131
+ const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
132
+ const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
133
+ const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
134
+ const dim3 grid_size(num_blocks);
135
+
136
+
137
+ const int64_t s01 = nb01/sizeof(src_t);
138
+ const int64_t s02 = nb02/sizeof(src_t);
139
+ const int64_t s03 = nb03/sizeof(src_t);
140
+ const int64_t s10 = nb10/sizeof(idx_t);
141
+ const int64_t s11 = nb11/sizeof(idx_t);
142
+ const int64_t s12 = nb12/sizeof(idx_t);
143
+ const int64_t s1 = nb1/sizeof(dst_t);
144
+ const int64_t s2 = nb2/sizeof(dst_t);
145
+ const int64_t s3 = nb3/sizeof(dst_t);
146
+
147
+ if (ne_total > 0) {
148
+ k_set_rows<<<grid_size, block_size, 0, stream>>>(
149
+ src0_d, src1_d, dst_d,
150
+ ne00, ne01, ne02, ne03,
151
+ ne10, ne11, ne12, ne13,
152
+ s01, s02, s03,
153
+ s10, s11, s12,
154
+ s1, s2, s3);
155
+ }
156
+ }
157
+
158
+ template<typename src_t, typename idx_t>
159
+ static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
160
+ const src_t * src0_d = (const src_t *)src0->data;
161
+ const idx_t * src1_d = (const idx_t *)src1->data;
162
+
163
+ GGML_TENSOR_BINARY_OP_LOCALS
164
+
165
+ cudaStream_t stream = ctx.stream();
166
+
167
+
168
+ if (dst->type == GGML_TYPE_F32) {
169
+ set_rows_cuda(
170
+ src0_d, src1_d, (float*)dst->data,
171
+ ne00, ne01, ne02, ne03,
172
+ ne10, ne11, ne12, ne13,
173
+ nb01, nb02, nb03,
174
+ nb10, nb11, nb12,
175
+ nb1, nb2, nb3,
176
+ stream
177
+ );
178
+ } else if (dst->type == GGML_TYPE_F16) {
179
+ set_rows_cuda(
180
+ src0_d, src1_d, (half*)dst->data,
181
+ ne00, ne01, ne02, ne03,
182
+ ne10, ne11, ne12, ne13,
183
+ nb01, nb02, nb03,
184
+ nb10, nb11, nb12,
185
+ nb1, nb2, nb3,
186
+ stream
187
+ );
188
+ } else if (dst->type == GGML_TYPE_BF16) {
189
+ set_rows_cuda(
190
+ src0_d, src1_d, (nv_bfloat16*)dst->data,
191
+ ne00, ne01, ne02, ne03,
192
+ ne10, ne11, ne12, ne13,
193
+ nb01, nb02, nb03,
194
+ nb10, nb11, nb12,
195
+ nb1, nb2, nb3,
196
+ stream
197
+ );
198
+ } else if (dst->type == GGML_TYPE_Q4_0) {
199
+ set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
200
+ src0_d, src1_d, (block_q4_0*)dst->data,
201
+ ne00, ne01, ne02, ne03,
202
+ ne10, ne11, ne12, ne13,
203
+ nb01, nb02, nb03,
204
+ nb10, nb11, nb12,
205
+ nb1, nb2, nb3,
206
+ stream
207
+ );
208
+ } else if (dst->type == GGML_TYPE_Q4_1) {
209
+ set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
210
+ src0_d, src1_d, (block_q4_1*)dst->data,
211
+ ne00, ne01, ne02, ne03,
212
+ ne10, ne11, ne12, ne13,
213
+ nb01, nb02, nb03,
214
+ nb10, nb11, nb12,
215
+ nb1, nb2, nb3,
216
+ stream
217
+ );
218
+ } else if (dst->type == GGML_TYPE_Q5_0) {
219
+ set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
220
+ src0_d, src1_d, (block_q5_0*)dst->data,
221
+ ne00, ne01, ne02, ne03,
222
+ ne10, ne11, ne12, ne13,
223
+ nb01, nb02, nb03,
224
+ nb10, nb11, nb12,
225
+ nb1, nb2, nb3,
226
+ stream
227
+ );
228
+ } else if (dst->type == GGML_TYPE_Q5_1) {
229
+ set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
230
+ src0_d, src1_d, (block_q5_1*)dst->data,
231
+ ne00, ne01, ne02, ne03,
232
+ ne10, ne11, ne12, ne13,
233
+ nb01, nb02, nb03,
234
+ nb10, nb11, nb12,
235
+ nb1, nb2, nb3,
236
+ stream
237
+ );
238
+ } else if (dst->type == GGML_TYPE_Q8_0) {
239
+ set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
240
+ src0_d, src1_d, (block_q8_0*)dst->data,
241
+ ne00, ne01, ne02, ne03,
242
+ ne10, ne11, ne12, ne13,
243
+ nb01, nb02, nb03,
244
+ nb10, nb11, nb12,
245
+ nb1, nb2, nb3,
246
+ stream
247
+ );
248
+ } else if (dst->type == GGML_TYPE_IQ4_NL) {
249
+ set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
250
+ src0_d, src1_d, (block_iq4_nl*)dst->data,
251
+ ne00, ne01, ne02, ne03,
252
+ ne10, ne11, ne12, ne13,
253
+ nb01, nb02, nb03,
254
+ nb10, nb11, nb12,
255
+ nb1, nb2, nb3,
256
+ stream
257
+ );
258
+ } else {
259
+ GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
260
+ }
261
+ }
262
+
263
+
264
+ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
265
+ const ggml_tensor * src0 = dst->src[0];
266
+ const ggml_tensor * src1 = dst->src[1];
267
+
268
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
269
+ GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);
270
+
271
+ if (src1->type == GGML_TYPE_I64) {
272
+ set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
273
+ } else {
274
+ set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
275
+ }
276
+ }
@@ -0,0 +1,7 @@
1
+ #pragma once
2
+
3
+ #include "common.cuh"
4
+
5
+ #define CUDA_SET_ROWS_BLOCK_SIZE 256
6
+
7
+ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -0,0 +1,34 @@
1
+ #include "softcap.cuh"
2
+
3
+ static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) {
4
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
5
+
6
+ if (i >= k) {
7
+ return;
8
+ }
9
+
10
+ dst[i] = tanhf(scale * x[i]) * softcap;
11
+ }
12
+
13
+ static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) {
14
+ const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE;
15
+ softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k);
16
+ }
17
+
18
+ // fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
19
+ void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src) {
20
+ const ggml_tensor * src0 = src->src[0];
21
+ const float * src0_d = (const float *)src0->data;
22
+ float * dst_d = (float *)dst->data;
23
+ cudaStream_t stream = ctx.stream();
24
+
25
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
26
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
27
+
28
+ float scale;
29
+ float softcap;
30
+ memcpy(&scale, (float *) src->op_params + 0, sizeof(float));
31
+ memcpy(&softcap, (float *) dst->op_params + 0, sizeof(float));
32
+
33
+ softcap_f32_cuda(src0_d, dst_d, scale, softcap, ggml_nelements(src0), stream);
34
+ }
@@ -0,0 +1,5 @@
1
+ #include "common.cuh"
2
+
3
+ #define CUDA_SOFTCAP_BLOCK_SIZE 256
4
+
5
+ void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src);
@@ -2,6 +2,7 @@
2
2
  #include "ggml.h"
3
3
  #include "softmax.cuh"
4
4
  #include <cstdint>
5
+ #include <utility>
5
6
 
6
7
  template <typename T>
7
8
  static __device__ __forceinline__ float t2f32(T val) {
@@ -13,6 +14,29 @@ __device__ float __forceinline__ t2f32<half>(half val) {
13
14
  return __half2float(val);
14
15
  }
15
16
 
17
+ struct soft_max_params {
18
+
19
+ int64_t nheads;
20
+ uint32_t n_head_log2;
21
+ int64_t ncols;
22
+ int64_t nrows_x;
23
+ int64_t nrows_y;
24
+ int64_t ne00;
25
+ int64_t ne01;
26
+ int64_t ne02;
27
+ int64_t ne03;
28
+ int64_t nb11;
29
+ int64_t nb12;
30
+ int64_t nb13;
31
+
32
+ int64_t ne12;
33
+ int64_t ne13;
34
+ float scale;
35
+ float max_bias;
36
+ float m0;
37
+ float m1;
38
+ };
39
+
16
40
  // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
17
41
  // As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
18
42
  #ifdef __clang__
@@ -21,16 +45,24 @@ __device__ float __forceinline__ t2f32<half>(half val) {
21
45
  #endif // __clang__
22
46
  template <bool use_shared, int ncols_template, int block_size_template, typename T>
23
47
  static __global__ void soft_max_f32(
24
- const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
25
- const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
26
- const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
48
+ const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
49
+ const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
27
50
 
28
51
  const int tid = threadIdx.x;
29
- const int rowx = blockIdx.x;
30
- const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
52
+
53
+ const int64_t i03 = blockIdx.z;
54
+ const int64_t i02 = blockIdx.y;
55
+ const int64_t i01 = blockIdx.x;
56
+
57
+ //TODO: noncontigous inputs/outputs
58
+ const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
59
+
60
+ const int64_t i11 = i01;
61
+ const int64_t i12 = i02 % p.ne12;
62
+ const int64_t i13 = i03 % p.ne13;
31
63
 
32
64
  x += int64_t(rowx)*ncols;
33
- mask += int64_t(rowy)*ncols * (mask != nullptr);
65
+ mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
34
66
  dst += int64_t(rowx)*ncols;
35
67
 
36
68
  const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
@@ -38,14 +70,14 @@ static __global__ void soft_max_f32(
38
70
  const int warp_id = threadIdx.x / WARP_SIZE;
39
71
  const int lane_id = threadIdx.x % WARP_SIZE;
40
72
 
41
- const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
73
+ const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
42
74
 
43
75
  extern __shared__ float data_soft_max_f32[];
44
76
  float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
45
77
  // shared memory buffer to cache values between iterations:
46
78
  float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
47
79
 
48
- float max_val = -INFINITY;
80
+ float max_val = sinks ? sinks[i02] : -INFINITY;
49
81
 
50
82
  #pragma unroll
51
83
  for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -55,7 +87,7 @@ static __global__ void soft_max_f32(
55
87
  break;
56
88
  }
57
89
 
58
- const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
90
+ const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
59
91
 
60
92
  vals[col] = val;
61
93
  max_val = max(max_val, val);
@@ -111,6 +143,10 @@ static __global__ void soft_max_f32(
111
143
  tmp = warp_reduce_sum(tmp);
112
144
  }
113
145
 
146
+ if (sinks) {
147
+ tmp += expf(sinks[i02] - max_val);
148
+ }
149
+
114
150
  const float inv_sum = 1.0f / tmp;
115
151
 
116
152
  #pragma unroll
@@ -150,64 +186,58 @@ static __global__ void soft_max_back_f32(
150
186
  }
151
187
  }
152
188
 
189
+ template<int... Ns, typename T>
190
+ static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
191
+ const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
192
+ {
193
+ const int id = ggml_cuda_get_device();
194
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
195
+
196
+ auto launch_kernel = [=](auto I) -> bool {
197
+ constexpr int ncols = decltype(I)::value;
198
+ constexpr int block = (ncols > 1024 ? 1024 : ncols);
199
+
200
+ if (p.ncols == ncols) {
201
+ CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
202
+ soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
203
+ (x, mask, sinks, dst, p);
204
+ return true;
205
+ }
206
+ return false;
207
+ };
208
+
209
+ // unary fold over launch_kernel
210
+ if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
211
+ return;
212
+ }
213
+
214
+ //default case
215
+ CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
216
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
217
+ }
218
+
219
+
153
220
  template<typename T>
154
- static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
221
+ static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
155
222
  int nth = WARP_SIZE;
223
+ const int64_t ncols_x = params.ncols;
224
+
156
225
  while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
157
226
  const dim3 block_dims(nth, 1, 1);
158
- const dim3 block_nums(nrows_x, 1, 1);
227
+ const dim3 block_nums(params.ne01, params.ne02, params.ne03);
159
228
  const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
160
229
  static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
161
230
 
162
- const uint32_t n_head = nrows_x/nrows_y;
163
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
164
231
 
165
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
166
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
232
+ const int id = ggml_cuda_get_device();
233
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
167
234
 
168
- // FIXME: this limit could be raised by ~2-4x on Ampere or newer
169
- if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
170
- switch (ncols_x) {
171
- case 32:
172
- soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
173
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
174
- break;
175
- case 64:
176
- soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
177
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
178
- break;
179
- case 128:
180
- soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
181
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
182
- break;
183
- case 256:
184
- soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
185
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
186
- break;
187
- case 512:
188
- soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
189
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
190
- break;
191
- case 1024:
192
- soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
193
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
194
- break;
195
- case 2048:
196
- soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
197
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
198
- break;
199
- case 4096:
200
- soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
201
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
202
- break;
203
- default:
204
- soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
205
- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
206
- break;
207
- }
235
+
236
+ if (nbytes_shared <= smpbo) {
237
+ launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
208
238
  } else {
209
239
  const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
210
- soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
240
+ soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
211
241
  }
212
242
  }
213
243
 
@@ -223,9 +253,11 @@ static void soft_max_back_f32_cuda(
223
253
  void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
224
254
  const ggml_tensor * src0 = dst->src[0];
225
255
  const ggml_tensor * src1 = dst->src[1];
256
+ const ggml_tensor * src2 = dst->src[2];
226
257
 
227
258
  const float * src0_d = (const float *) src0->data;
228
259
  const void * src1_d = src1 ? (const void *) src1->data : nullptr;
260
+ const void * src2_d = src2 ? (const void *) src2->data : nullptr;
229
261
  float * dst_d = (float *) dst->data;
230
262
 
231
263
  cudaStream_t stream = ctx.stream();
@@ -235,10 +267,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
235
267
 
236
268
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
237
269
 
238
- const int64_t ne00 = src0->ne[0];
239
270
  const int64_t nrows_x = ggml_nrows(src0);
240
271
  const int64_t nrows_y = src0->ne[1];
241
272
 
273
+ const int64_t ne00 = src0->ne[0];
274
+
242
275
  float scale = 1.0f;
243
276
  float max_bias = 0.0f;
244
277
 
@@ -247,10 +280,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
247
280
 
248
281
  const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
249
282
 
283
+ const int64_t nb11 = src1 ? src1->nb[1] : 1;
284
+ const int64_t nb12 = src1 ? src1->nb[2] : 1;
285
+ const int64_t nb13 = src1 ? src1->nb[3] : 1;
286
+
287
+ const int64_t ne12 = src1 ? src1->ne[2] : 1;
288
+ const int64_t ne13 = src1 ? src1->ne[3] : 1;
289
+
290
+ const uint32_t n_head = src0->ne[2];
291
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
292
+
293
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
294
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
295
+
296
+
297
+ soft_max_params params = {};
298
+ params.nheads = src0->ne[2];
299
+ params.n_head_log2 = n_head_log2;
300
+ params.ncols = ne00;
301
+ params.nrows_x = nrows_x;
302
+ params.nrows_y = nrows_y;
303
+ params.ne00 = src0->ne[0];
304
+ params.ne01 = src0->ne[1];
305
+ params.ne02 = src0->ne[2];
306
+ params.ne03 = src0->ne[3];
307
+ params.nb11 = nb11;
308
+ params.nb12 = nb12;
309
+ params.nb13 = nb13;
310
+ params.ne12 = ne12;
311
+ params.ne13 = ne13;
312
+ params.scale = scale;
313
+ params.max_bias = max_bias;
314
+ params.m0 = m0;
315
+ params.m1 = m1;
316
+
250
317
  if (use_f16) {
251
- soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
318
+ soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
252
319
  } else {
253
- soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
320
+ soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
254
321
  }
255
322
  }
256
323
 
@@ -107,8 +107,11 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
107
107
  if (nc == 4) {
108
108
  ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
109
109
  dst, dst_nb0, dst_nb1, dst_nb2, n_t);
110
+ } else if (nc == 3) {
111
+ ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
112
+ dst, dst_nb0, dst_nb1, dst_nb2, n_t);
110
113
  } else {
111
- GGML_ABORT("Only support kernel size = 4 now.");
114
+ GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
112
115
  }
113
116
  } else {
114
117
  if (nc == 4) {
@@ -116,8 +119,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
116
119
  dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
117
120
  ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
118
121
  src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
122
+ } else if (nc == 3) {
123
+ const int64_t split_n_t = 32;
124
+ dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
125
+ ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
126
+ src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
119
127
  } else {
120
- GGML_ABORT("Only support kernel size = 4 right now.");
128
+ GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
121
129
  }
122
130
  }
123
131
  }