whispercpp 1.3.2 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (664) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +59 -27
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/build-xcframework.sh +24 -0
  19. data/ext/sources/examples/CMakeLists.txt +1 -0
  20. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  21. data/ext/sources/examples/addon.node/addon.cpp +154 -35
  22. data/ext/sources/examples/addon.node/index.js +10 -5
  23. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  24. data/ext/sources/examples/bench/bench.cpp +29 -18
  25. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  26. data/ext/sources/examples/cli/cli.cpp +7 -4
  27. data/ext/sources/examples/command/command.cpp +58 -32
  28. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/common-whisper.cpp +14 -7
  31. data/ext/sources/examples/lsp/lsp.cpp +21 -17
  32. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  33. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  34. data/ext/sources/examples/server/server.cpp +193 -35
  35. data/ext/sources/examples/server.py +6 -1
  36. data/ext/sources/examples/stream/stream.cpp +10 -2
  37. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  38. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  39. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
  40. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  41. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  42. data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
  43. data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
  44. data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
  45. data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
  46. data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
  47. data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
  48. data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
  49. data/ext/sources/examples/talk-llama/llama-context.h +68 -32
  50. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  52. data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
  53. data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
  54. data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
  55. data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
  56. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  57. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  58. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
  59. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
  60. data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
  61. data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
  62. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
  63. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
  64. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
  65. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
  66. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  67. data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
  68. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  69. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
  70. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  71. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  72. data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
  73. data/ext/sources/examples/talk-llama/llama-model.h +87 -9
  74. data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
  75. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  76. data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
  77. data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
  78. data/ext/sources/examples/talk-llama/llama.cpp +76 -17
  79. data/ext/sources/examples/talk-llama/llama.h +176 -151
  80. data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
  81. data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
  82. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  83. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  84. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
  85. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  86. data/ext/sources/ggml/CMakeLists.txt +106 -33
  87. data/ext/sources/ggml/cmake/common.cmake +24 -0
  88. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  89. data/ext/sources/ggml/include/ggml-backend.h +18 -2
  90. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  91. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  92. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  93. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  94. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  95. data/ext/sources/ggml/include/ggml.h +365 -21
  96. data/ext/sources/ggml/src/CMakeLists.txt +98 -25
  97. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  98. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  99. data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
  100. data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
  101. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
  102. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  103. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
  104. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  105. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  106. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  107. data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
  108. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
  109. data/ext/sources/ggml/src/ggml-common.h +21 -0
  110. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
  111. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
  112. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  113. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  114. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
  115. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
  116. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
  117. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  118. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  119. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
  120. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  121. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  122. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  123. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  124. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  125. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
  126. data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
  127. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
  128. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
  129. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
  130. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  131. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  132. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  133. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
  134. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
  135. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  136. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
  137. data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
  138. data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
  139. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
  140. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
  141. data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
  142. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
  143. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  144. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  145. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  146. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  147. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
  148. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
  149. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
  150. data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
  151. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  152. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  153. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  154. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  155. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  156. data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
  157. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  158. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  159. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  160. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  161. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  162. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  163. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  164. data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
  165. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
  166. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  167. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  168. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  169. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  170. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
  171. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
  172. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  173. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  174. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  175. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
  176. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  177. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  178. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  179. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
  180. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  181. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  182. data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
  183. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  184. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  185. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  186. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  187. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  188. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  189. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
  190. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
  191. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  192. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  193. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  195. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  196. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  197. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  198. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  199. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  200. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  201. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  202. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  203. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  204. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  205. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  206. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  208. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  210. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  211. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
  212. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  213. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
  214. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  234. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  235. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  236. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  237. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  238. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  239. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  240. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  241. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  242. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  243. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  244. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  245. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  246. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  247. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  248. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  249. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  251. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  252. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  254. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  255. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  259. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  260. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  262. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  270. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  271. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  272. data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
  274. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  277. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  278. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  279. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
  280. data/ext/sources/ggml/src/ggml-impl.h +229 -175
  281. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
  282. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  283. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  284. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  285. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  286. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  287. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  288. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  289. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
  290. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  291. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  292. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  293. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
  294. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  295. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  296. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
  297. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  344. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  345. data/ext/sources/ggml/src/ggml-quants.c +117 -24
  346. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  347. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
  348. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  349. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  350. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
  351. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  352. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  353. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
  354. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
  355. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
  356. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  357. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  358. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
  359. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
  360. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
  361. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
  362. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
  363. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  364. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  365. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  366. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
  367. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
  368. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  369. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  370. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  371. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  372. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
  373. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
  374. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  375. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  401. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  402. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  403. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  404. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
  449. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  450. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  451. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  452. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  453. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  454. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  455. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  456. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  457. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  458. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  459. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  460. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  461. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  462. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  463. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  464. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  465. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  466. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  467. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  468. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  469. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  470. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  471. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  472. data/ext/sources/ggml/src/ggml.c +802 -142
  473. data/ext/sources/ggml/src/ggml.cpp +26 -0
  474. data/ext/sources/ggml/src/gguf.cpp +32 -4
  475. data/ext/sources/include/whisper.h +2 -0
  476. data/ext/sources/src/CMakeLists.txt +2 -0
  477. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  478. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  479. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  480. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  481. data/ext/sources/src/whisper.cpp +241 -215
  482. data/ext/sources/tests/CMakeLists.txt +8 -1
  483. data/ext/sources/tests/test-vad-full.cpp +3 -3
  484. data/ext/sources/tests/test-vad.cpp +2 -2
  485. data/extsources.rb +15 -9
  486. data/lib/whisper/context.rb +15 -0
  487. data/lib/whisper/model/uri.rb +57 -2
  488. data/lib/whisper/segment.rb +58 -0
  489. data/sig/whisper.rbs +75 -38
  490. data/{tests → test}/helper.rb +1 -12
  491. data/{tests → test}/test_model.rb +9 -0
  492. data/test/test_package.rb +51 -0
  493. data/{tests → test}/test_params.rb +8 -0
  494. data/test/test_segment.rb +146 -0
  495. data/{tests → test}/test_whisper.rb +70 -0
  496. data/whispercpp.gemspec +2 -3
  497. metadata +246 -191
  498. data/ext/sources/.dockerignore +0 -3
  499. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  500. data/ext/sources/ci/run.sh +0 -336
  501. data/ext/sources/close-issue.yml +0 -28
  502. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  503. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  504. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  505. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  506. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  507. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  508. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  509. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  510. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  511. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  512. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  513. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  514. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  515. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  516. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  517. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  518. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
  519. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  520. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  521. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  522. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  523. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  524. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  525. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  526. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  527. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  548. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  549. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  550. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  551. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  552. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  553. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  554. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  555. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  556. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  557. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  558. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  559. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  560. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  561. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  562. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  563. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  564. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  565. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  566. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  567. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  568. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  569. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  570. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  571. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  572. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  573. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  574. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  575. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  576. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  577. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  578. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  579. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  580. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  581. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  582. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  583. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  584. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  585. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  586. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  587. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  588. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  589. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  590. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  591. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  592. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  593. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  594. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  595. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  596. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  597. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  598. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  599. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  600. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  601. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  602. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  603. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  604. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  605. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  606. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  607. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  608. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  609. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  610. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  611. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  612. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  613. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  614. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  615. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  616. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  617. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  618. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  619. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  620. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  621. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  622. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  623. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  624. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  625. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  626. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  627. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  628. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  629. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  630. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  631. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  632. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  633. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  634. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  635. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  636. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  637. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  638. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  639. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  640. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  641. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  642. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  643. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  644. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  645. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  646. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  647. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  648. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  649. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  650. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  651. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  652. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  653. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
  654. data/tests/test_package.rb +0 -46
  655. data/tests/test_segment.rb +0 -74
  656. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  657. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  658. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  659. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  660. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  661. /data/{tests → test}/test_callback.rb +0 -0
  662. /data/{tests → test}/test_error.rb +0 -0
  663. /data/{tests → test}/test_vad.rb +0 -0
  664. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -13,6 +13,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
13
13
  case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
14
14
  case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
15
15
  case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
16
+ case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
16
17
  case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
17
18
  case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
18
19
  case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
@@ -38,6 +39,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
38
39
  case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
39
40
  case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
40
41
  case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
42
+ case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
41
43
  case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
42
44
  case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
43
45
  case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
@@ -139,9 +141,10 @@ template <ggml_type type, int ncols_dst>
139
141
  __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
140
142
  static __global__ void mul_mat_vec_q(
141
143
  const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
142
- const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
143
- const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
144
- const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
144
+ const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
145
+ const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
146
+ const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
147
+ const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
145
148
 
146
149
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
147
150
  constexpr int qi = ggml_cuda_type_traits<type>::qi;
@@ -159,12 +162,12 @@ static __global__ void mul_mat_vec_q(
159
162
  constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
160
163
 
161
164
  // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
162
- const int channel_dst = blockIdx.y;
163
- const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio;
164
- const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
165
- const int sample_dst = blockIdx.z;
166
- const int sample_x = sample_dst / sample_ratio;
167
- const int sample_y = sample_dst;
165
+ const uint32_t channel_dst = blockIdx.y;
166
+ const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
167
+ const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
168
+ const uint32_t sample_dst = blockIdx.z;
169
+ const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
170
+ const uint32_t sample_y = sample_dst;
168
171
 
169
172
  // partial sum for each thread
170
173
  float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
@@ -217,7 +220,7 @@ static __global__ void mul_mat_vec_q(
217
220
  tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
218
221
  }
219
222
 
220
- if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + int(threadIdx.x) < stride_col_dst)) {
223
+ if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
221
224
  dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x];
222
225
  }
223
226
  }
@@ -245,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst(
245
248
  GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
246
249
  GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
247
250
 
248
- const int channel_ratio = nchannels_dst / nchannels_x;
249
- const int sample_ratio = nsamples_dst / nsamples_x;
251
+ const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
252
+ const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
253
+ const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
250
254
 
251
255
  const int device = ggml_cuda_get_device();
252
256
  const int warp_size = ggml_cuda_info().devices[device].warp_size;
@@ -254,86 +258,70 @@ static void mul_mat_vec_q_switch_ncols_dst(
254
258
 
255
259
  GGML_ASSERT(!ids || ncols_dst == 1);
256
260
  switch (ncols_dst) {
257
- case 1:
258
- {
261
+ case 1: {
259
262
  constexpr int c_ncols_dst = 1;
260
263
  std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
261
264
  mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
262
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
263
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
264
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
265
- break;
266
- }
267
- case 2:
268
- {
265
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
266
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
267
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
268
+ } break;
269
+ case 2: {
269
270
  constexpr int c_ncols_dst = 2;
270
271
  std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
271
272
  mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
272
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
273
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
274
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
275
- break;
276
- }
277
- case 3:
278
- {
273
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
274
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
275
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
276
+ } break;
277
+ case 3: {
279
278
  constexpr int c_ncols_dst = 3;
280
279
  std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
281
280
  mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
282
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
283
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
284
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
285
- break;
286
- }
287
- case 4:
288
- {
281
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
282
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
283
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
284
+ } break;
285
+ case 4: {
289
286
  constexpr int c_ncols_dst = 4;
290
287
  std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
291
288
  mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
292
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
293
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
294
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
295
- break;
296
- }
297
- case 5:
298
- {
289
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
290
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
291
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
292
+ } break;
293
+ case 5: {
299
294
  constexpr int c_ncols_dst = 5;
300
295
  std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
301
296
  mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
302
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
303
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
304
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
305
- break;
306
- }
307
- case 6:
308
- {
297
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
298
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
299
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
300
+ } break;
301
+ case 6: {
309
302
  constexpr int c_ncols_dst = 6;
310
303
  std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
311
304
  mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
312
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
313
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
314
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
315
- break;
316
- }
317
- case 7:
318
- {
305
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
306
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
307
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
308
+ } break;
309
+ case 7: {
319
310
  constexpr int c_ncols_dst = 7;
320
311
  std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
321
312
  mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
322
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
323
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
324
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
325
- break;
326
- }
327
- case 8:
328
- {
313
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
314
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
315
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
316
+ } break;
317
+ case 8: {
329
318
  constexpr int c_ncols_dst = 8;
330
319
  std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
331
320
  mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
332
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
333
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
334
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
335
- break;
336
- }
321
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
322
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
323
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
324
+ } break;
337
325
  default:
338
326
  GGML_ABORT("fatal error");
339
327
  break;
@@ -384,6 +372,13 @@ static void mul_mat_vec_q_switch_type(
384
372
  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
385
373
  stream);
386
374
  break;
375
+ case GGML_TYPE_MXFP4:
376
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
377
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
378
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
379
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
380
+ stream);
381
+ break;
387
382
  case GGML_TYPE_Q2_K:
388
383
  mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
389
384
  (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
@@ -587,9 +582,5 @@ void ggml_cuda_op_mul_mat_vec_q(
587
582
  src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
588
583
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
589
584
 
590
- GGML_UNUSED(src1);
591
- GGML_UNUSED(dst);
592
- GGML_UNUSED(src1_ddf_i);
593
- GGML_UNUSED(src1_ncols);
594
- GGML_UNUSED(src1_padded_row_size);
585
+ GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
595
586
  }
@@ -104,10 +104,30 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
104
104
  }
105
105
  }
106
106
 
107
- template <int block_size>
108
- static __global__ void rms_norm_f32(
109
- const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
110
- const int64_t stride_sample, const float eps) {
107
+ template <int block_size, bool do_multiply = false, bool do_add = false>
108
+ static __global__ void rms_norm_f32(const float * x,
109
+ float * dst,
110
+ const int ncols,
111
+ const int64_t stride_row,
112
+ const int64_t stride_channel,
113
+ const int64_t stride_sample,
114
+ const float eps,
115
+ const float * mul = nullptr,
116
+ const int64_t mul_stride_row = 0,
117
+ const int64_t mul_stride_channel = 0,
118
+ const int64_t mul_stride_sample = 0,
119
+ const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
120
+ const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
121
+ const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
122
+ const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
123
+ const float * add = nullptr,
124
+ const int64_t add_stride_row = 0,
125
+ const int64_t add_stride_channel = 0,
126
+ const int64_t add_stride_sample = 0,
127
+ const uint3 add_ncols_packed = make_uint3(0, 0, 0),
128
+ const uint3 add_nrows_packed = make_uint3(0, 0, 0),
129
+ const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
130
+ const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
111
131
  const int nrows = gridDim.x;
112
132
  const int nchannels = gridDim.y;
113
133
 
@@ -116,9 +136,25 @@ static __global__ void rms_norm_f32(
116
136
  const int sample = blockIdx.z;
117
137
  const int tid = threadIdx.x;
118
138
 
139
+ static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
140
+
119
141
  x += sample*stride_sample + channel*stride_channel + row*stride_row;
120
142
  dst += ((sample*nchannels + channel)*nrows + row)*ncols;
121
143
 
144
+ if constexpr (do_multiply) {
145
+ const uint32_t mul_row = fastmodulo(row, mul_nrows_packed);
146
+ const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
147
+ const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed);
148
+ mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
149
+ }
150
+
151
+ if constexpr (do_add) {
152
+ const int add_row = fastmodulo(row, add_nrows_packed);
153
+ const int add_channel = fastmodulo(channel, add_nchannels_packed);
154
+ const int add_sample = fastmodulo(sample, add_nsamples_packed);
155
+ add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
156
+ }
157
+
122
158
  float tmp = 0.0f; // partial sum for thread in warp
123
159
 
124
160
  for (int col = tid; col < ncols; col += block_size) {
@@ -129,15 +165,18 @@ static __global__ void rms_norm_f32(
129
165
  // sum up partial sums
130
166
  tmp = warp_reduce_sum(tmp);
131
167
  if constexpr (block_size > WARP_SIZE) {
132
- static_assert(block_size == 1024, "unexpected block_size");
168
+ static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
133
169
  __shared__ float s_sum[32];
134
- const int warp_id = threadIdx.x / WARP_SIZE;
135
- const int lane_id = threadIdx.x % WARP_SIZE;
170
+ const int warp_id = tid / WARP_SIZE;
171
+ const int lane_id = tid % WARP_SIZE;
136
172
  if (lane_id == 0) {
137
173
  s_sum[warp_id] = tmp;
138
174
  }
139
175
  __syncthreads();
140
- tmp = s_sum[lane_id];
176
+ tmp = 0.0f;
177
+ if (lane_id < (block_size / WARP_SIZE)) {
178
+ tmp = s_sum[lane_id];
179
+ }
141
180
  tmp = warp_reduce_sum(tmp);
142
181
  }
143
182
 
@@ -145,7 +184,16 @@ static __global__ void rms_norm_f32(
145
184
  const float scale = rsqrtf(mean + eps);
146
185
 
147
186
  for (int col = tid; col < ncols; col += block_size) {
148
- dst[col] = scale * x[col];
187
+ if constexpr (do_multiply && do_add) {
188
+ const int mul_col = fastmodulo(col, mul_ncols_packed);
189
+ const int add_col = fastmodulo(col, add_ncols_packed);
190
+ dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
191
+ } else if constexpr (do_multiply) {
192
+ const int mul_col = fastmodulo(col, mul_ncols_packed);
193
+ dst[col] = scale * x[col] * mul[mul_col];
194
+ } else {
195
+ dst[col] = scale * x[col];
196
+ }
149
197
  }
150
198
  }
151
199
 
@@ -309,11 +357,87 @@ static void rms_norm_f32_cuda(
309
357
  const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
310
358
  const dim3 blocks_num(nrows, nchannels, nsamples);
311
359
  if (ncols < 1024) {
312
- const dim3 block_dims(WARP_SIZE, 1, 1);
313
- rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
360
+ const dim3 block_dims(256, 1, 1);
361
+ rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
314
362
  } else {
315
363
  const dim3 block_dims(1024, 1, 1);
316
- rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
364
+ rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
365
+ }
366
+ }
367
+
368
+ static void rms_norm_mul_f32_cuda(const float * x,
369
+ const float * mul,
370
+ const float * add,
371
+ float * dst,
372
+ const int ncols,
373
+ const int nrows,
374
+ const int nchannels,
375
+ const int nsamples,
376
+ const int64_t stride_row,
377
+ const int64_t stride_channel,
378
+ const int64_t stride_sample,
379
+ const int64_t mul_stride_row,
380
+ const int64_t mul_stride_channel,
381
+ const int64_t mul_stride_sample,
382
+ const uint32_t mul_ncols,
383
+ const uint32_t mul_nrows,
384
+ const uint32_t mul_nchannels,
385
+ const uint32_t mul_nsamples,
386
+ const int64_t add_stride_row,
387
+ const int64_t add_stride_channel,
388
+ const int64_t add_stride_sample,
389
+ const uint32_t add_ncols,
390
+ const uint32_t add_nrows,
391
+ const uint32_t add_nchannels,
392
+ const uint32_t add_nsamples,
393
+ const float eps,
394
+ cudaStream_t stream) {
395
+ const dim3 blocks_num(nrows, nchannels, nsamples);
396
+ if (mul == nullptr) {
397
+ rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
398
+ return;
399
+ }
400
+ if (add == nullptr) {
401
+ const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
402
+ const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
403
+ const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
404
+ const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
405
+ if (ncols < 1024) {
406
+ const dim3 block_dims(256, 1, 1);
407
+ rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
408
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
409
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
410
+ } else {
411
+ const dim3 block_dims(1024, 1, 1);
412
+ rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
413
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
414
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
415
+ }
416
+ } else {
417
+ const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
418
+ const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
419
+ const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
420
+ const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
421
+
422
+ const uint3 add_ncols_packed = init_fastdiv_values(add_ncols);
423
+ const uint3 add_nrows_packed = init_fastdiv_values(add_nrows);
424
+ const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
425
+ const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
426
+ if (ncols < 1024) {
427
+ const dim3 block_dims(256, 1, 1);
428
+ rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
429
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
430
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
431
+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
432
+ add_nchannels_packed, add_nsamples_packed);
433
+ } else {
434
+ const dim3 block_dims(1024, 1, 1);
435
+ rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
436
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
437
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
438
+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
439
+ add_nchannels_packed, add_nsamples_packed);
440
+ }
317
441
  }
318
442
  }
319
443
 
@@ -407,6 +531,154 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
407
531
  rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
408
532
  }
409
533
 
534
+ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
535
+ const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
536
+ float eps = 0.0f;
537
+
538
+ memcpy(&eps, dst->op_params, sizeof(float));
539
+
540
+ const float * src0_d = (const float *) rms_norm_src->data;
541
+ const float * mul_d = nullptr;
542
+ const ggml_tensor * mul_src = nullptr;
543
+
544
+ if (mul_tensor->src[0] == dst) {
545
+ mul_d = (float *) mul_tensor->src[1]->data;
546
+ mul_src = mul_tensor->src[1];
547
+ } else if(mul_tensor->src[1] == dst) {
548
+ mul_d = (float *) mul_tensor->src[0]->data;
549
+ mul_src = mul_tensor->src[0];
550
+ } else {
551
+ GGML_ASSERT(false);
552
+ }
553
+
554
+ float * dst_d = (float *) mul_tensor->data;
555
+ cudaStream_t stream = ctx.stream();
556
+
557
+ GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
558
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
559
+ GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
560
+ GGML_ASSERT(eps >= 0.0f);
561
+
562
+ const int64_t ne00 = rms_norm_src->ne[0];
563
+ const int64_t ne01 = rms_norm_src->ne[1];
564
+ const int64_t ne02 = rms_norm_src->ne[2];
565
+ const int64_t ne03 = rms_norm_src->ne[3];
566
+
567
+ const size_t ts0 = ggml_type_size(rms_norm_src->type);
568
+ GGML_ASSERT(rms_norm_src->nb[0] == ts0);
569
+ const int64_t s01 = rms_norm_src->nb[1] / ts0;
570
+ const int64_t s02 = rms_norm_src->nb[2] / ts0;
571
+ const int64_t s03 = rms_norm_src->nb[3] / ts0;
572
+
573
+ const size_t ts_mul = ggml_type_size(mul_src->type);
574
+ GGML_ASSERT(mul_src->nb[0] == ts_mul);
575
+ const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
576
+ const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
577
+ const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
578
+
579
+ const int mul_ncols = mul_src->ne[0];
580
+ const int mul_nrows = mul_src->ne[1];
581
+ const int mul_nchannels = mul_src->ne[2];
582
+ const int mul_nsamples = mul_src->ne[3];
583
+
584
+ rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
585
+ ne00, ne01, ne02, ne03,
586
+ /*s00*/ s01, s02, s03,
587
+ /*mul_s00*/ mul_s01, mul_s02, mul_s03,
588
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
589
+ /*add_s00*/ 0, 0, 0,
590
+ 0, 0, 0, 0,
591
+ eps, stream);
592
+ }
593
+
594
+ void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
595
+ ggml_tensor * dst,
596
+ ggml_tensor * mul_tensor,
597
+ ggml_tensor * add_tensor) {
598
+ const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
599
+ float eps = 0.0f;
600
+
601
+ memcpy(&eps, dst->op_params, sizeof(float));
602
+
603
+ const float * src0_d = (const float *) rms_norm_src->data;
604
+ const float * mul_d = nullptr;
605
+ const ggml_tensor * mul_src = nullptr;
606
+
607
+ if (mul_tensor->src[0] == dst) {
608
+ mul_d = (float *) mul_tensor->src[1]->data;
609
+ mul_src = mul_tensor->src[1];
610
+ } else if (mul_tensor->src[1] == dst) {
611
+ mul_d = (float *) mul_tensor->src[0]->data;
612
+ mul_src = mul_tensor->src[0];
613
+ } else {
614
+ GGML_ASSERT(false);
615
+ }
616
+
617
+ const float * add_d = nullptr;
618
+ const ggml_tensor * add_src = nullptr;
619
+
620
+ if (add_tensor->src[0] == mul_tensor) {
621
+ add_d = (float *) add_tensor->src[1]->data;
622
+ add_src = add_tensor->src[1];
623
+ } else if (add_tensor->src[1] == mul_tensor) {
624
+ add_d = (float *) add_tensor->src[0]->data;
625
+ add_src = add_tensor->src[0];
626
+ } else {
627
+ GGML_ASSERT(false);
628
+ }
629
+
630
+ float * dst_d = (float *) add_tensor->data;
631
+ cudaStream_t stream = ctx.stream();
632
+
633
+ GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
634
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
635
+ GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
636
+ GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
637
+ GGML_ASSERT(eps >= 0.0f);
638
+
639
+ const int64_t ne00 = rms_norm_src->ne[0];
640
+ const int64_t ne01 = rms_norm_src->ne[1];
641
+ const int64_t ne02 = rms_norm_src->ne[2];
642
+ const int64_t ne03 = rms_norm_src->ne[3];
643
+
644
+ const size_t ts0 = ggml_type_size(rms_norm_src->type);
645
+ GGML_ASSERT(rms_norm_src->nb[0] == ts0);
646
+ const int64_t s01 = rms_norm_src->nb[1] / ts0;
647
+ const int64_t s02 = rms_norm_src->nb[2] / ts0;
648
+ const int64_t s03 = rms_norm_src->nb[3] / ts0;
649
+
650
+ const size_t ts_mul = ggml_type_size(mul_src->type);
651
+ GGML_ASSERT(mul_src->nb[0] == ts_mul);
652
+ const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
653
+ const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
654
+ const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
655
+
656
+ const int mul_ncols = mul_src->ne[0];
657
+ const int mul_nrows = mul_src->ne[1];
658
+ const int mul_nchannels = mul_src->ne[2];
659
+ const int mul_nsamples = mul_src->ne[3];
660
+
661
+ const size_t ts_add = ggml_type_size(add_src->type);
662
+ GGML_ASSERT(add_src->nb[0] == ts_add);
663
+ const int64_t add_s01 = add_src->nb[1] / ts_add;
664
+ const int64_t add_s02 = add_src->nb[2] / ts_add;
665
+ const int64_t add_s03 = add_src->nb[3] / ts_add;
666
+
667
+ const int add_ncols = add_src->ne[0];
668
+ const int add_nrows = add_src->ne[1];
669
+ const int add_nchannels = add_src->ne[2];
670
+ const int add_nsamples = add_src->ne[3];
671
+
672
+ rms_norm_mul_f32_cuda(src0_d, mul_d,add_d,dst_d,
673
+ ne00,ne01, ne02, ne03,
674
+ /*s00*/ s01, s02, s03,
675
+ /*mul_s00*/ mul_s01, mul_s02, mul_s03,
676
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
677
+ /*add_s00*/ add_s01, add_s02, add_s03,
678
+ add_ncols, add_nrows, add_nchannels, add_nsamples,
679
+ eps, stream);
680
+ }
681
+
410
682
  void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
411
683
  const ggml_tensor * grad = dst->src[0]; // gradients
412
684
  const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
@@ -6,6 +6,13 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
6
6
 
7
7
  void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
8
8
 
9
+ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
10
+
11
+ void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
12
+ ggml_tensor * dst,
13
+ ggml_tensor * mul_tensor,
14
+ ggml_tensor * add_tensor);
15
+
9
16
  void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
10
17
 
11
18
  void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -0,0 +1,49 @@
1
+ #include "ggml-impl.h"
2
+ #include "opt-step-sgd.cuh"
3
+
4
+ #include <cstdint>
5
+
6
+ static __global__ void opt_step_sgd_f32(
7
+ float * __restrict__ x, const float * __restrict__ g,
8
+ const float * __restrict__ pars, const int64_t k) {
9
+
10
+ const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
11
+
12
+ if (i >= k) {
13
+ return;
14
+ }
15
+ x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i];
16
+ }
17
+
18
+ static void opt_step_sgd_f32_cuda(
19
+ float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {
20
+
21
+ const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
22
+ const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
23
+ opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k);
24
+ }
25
+
26
+ void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
27
+ const ggml_tensor * src0 = dst->src[0];
28
+ const ggml_tensor * src0_grad = dst->src[1];
29
+ const ggml_tensor * params = dst->src[2];
30
+
31
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
32
+ GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
33
+ GGML_ASSERT(params->type == GGML_TYPE_F32);
34
+ GGML_ASSERT(ggml_is_contiguous(src0));
35
+ GGML_ASSERT(ggml_is_contiguous(src0_grad));
36
+ GGML_ASSERT(ggml_is_contiguous(params));
37
+ GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
38
+ GGML_ASSERT(ggml_nelements(params) == 2);
39
+
40
+ float * src0_d = (float *) src0->data;
41
+ const float * src0_grad_d = (const float *) src0_grad->data;
42
+ const float * params_d = (const float *) params->data;
43
+
44
+ cudaStream_t stream = ctx.stream();
45
+
46
+ const int64_t ne = ggml_nelements(src0);
47
+
48
+ opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream);
49
+ }
@@ -0,0 +1,5 @@
1
+ #include "common.cuh"
2
+
3
+ #define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256
4
+
5
+ void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);