whispercpp 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (586) hide show
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -104,10 +104,30 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
104
104
  }
105
105
  }
106
106
 
107
- template <int block_size>
108
- static __global__ void rms_norm_f32(
109
- const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
110
- const int64_t stride_sample, const float eps) {
107
+ template <int block_size, bool do_multiply = false, bool do_add = false>
108
+ static __global__ void rms_norm_f32(const float * x,
109
+ float * dst,
110
+ const int ncols,
111
+ const int64_t stride_row,
112
+ const int64_t stride_channel,
113
+ const int64_t stride_sample,
114
+ const float eps,
115
+ const float * mul = nullptr,
116
+ const int64_t mul_stride_row = 0,
117
+ const int64_t mul_stride_channel = 0,
118
+ const int64_t mul_stride_sample = 0,
119
+ const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
120
+ const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
121
+ const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
122
+ const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
123
+ const float * add = nullptr,
124
+ const int64_t add_stride_row = 0,
125
+ const int64_t add_stride_channel = 0,
126
+ const int64_t add_stride_sample = 0,
127
+ const uint3 add_ncols_packed = make_uint3(0, 0, 0),
128
+ const uint3 add_nrows_packed = make_uint3(0, 0, 0),
129
+ const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
130
+ const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
111
131
  const int nrows = gridDim.x;
112
132
  const int nchannels = gridDim.y;
113
133
 
@@ -116,9 +136,25 @@ static __global__ void rms_norm_f32(
116
136
  const int sample = blockIdx.z;
117
137
  const int tid = threadIdx.x;
118
138
 
139
+ static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
140
+
119
141
  x += sample*stride_sample + channel*stride_channel + row*stride_row;
120
142
  dst += ((sample*nchannels + channel)*nrows + row)*ncols;
121
143
 
144
+ if constexpr (do_multiply) {
145
+ const uint32_t mul_row = fastmodulo(row, mul_nrows_packed);
146
+ const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
147
+ const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed);
148
+ mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
149
+ }
150
+
151
+ if constexpr (do_add) {
152
+ const int add_row = fastmodulo(row, add_nrows_packed);
153
+ const int add_channel = fastmodulo(channel, add_nchannels_packed);
154
+ const int add_sample = fastmodulo(sample, add_nsamples_packed);
155
+ add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
156
+ }
157
+
122
158
  float tmp = 0.0f; // partial sum for thread in warp
123
159
 
124
160
  for (int col = tid; col < ncols; col += block_size) {
@@ -129,15 +165,18 @@ static __global__ void rms_norm_f32(
129
165
  // sum up partial sums
130
166
  tmp = warp_reduce_sum(tmp);
131
167
  if constexpr (block_size > WARP_SIZE) {
132
- static_assert(block_size == 1024, "unexpected block_size");
168
+ static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
133
169
  __shared__ float s_sum[32];
134
- const int warp_id = threadIdx.x / WARP_SIZE;
135
- const int lane_id = threadIdx.x % WARP_SIZE;
170
+ const int warp_id = tid / WARP_SIZE;
171
+ const int lane_id = tid % WARP_SIZE;
136
172
  if (lane_id == 0) {
137
173
  s_sum[warp_id] = tmp;
138
174
  }
139
175
  __syncthreads();
140
- tmp = s_sum[lane_id];
176
+ tmp = 0.0f;
177
+ if (lane_id < (block_size / WARP_SIZE)) {
178
+ tmp = s_sum[lane_id];
179
+ }
141
180
  tmp = warp_reduce_sum(tmp);
142
181
  }
143
182
 
@@ -145,7 +184,16 @@ static __global__ void rms_norm_f32(
145
184
  const float scale = rsqrtf(mean + eps);
146
185
 
147
186
  for (int col = tid; col < ncols; col += block_size) {
148
- dst[col] = scale * x[col];
187
+ if constexpr (do_multiply && do_add) {
188
+ const int mul_col = fastmodulo(col, mul_ncols_packed);
189
+ const int add_col = fastmodulo(col, add_ncols_packed);
190
+ dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
191
+ } else if constexpr (do_multiply) {
192
+ const int mul_col = fastmodulo(col, mul_ncols_packed);
193
+ dst[col] = scale * x[col] * mul[mul_col];
194
+ } else {
195
+ dst[col] = scale * x[col];
196
+ }
149
197
  }
150
198
  }
151
199
 
@@ -309,11 +357,87 @@ static void rms_norm_f32_cuda(
309
357
  const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
310
358
  const dim3 blocks_num(nrows, nchannels, nsamples);
311
359
  if (ncols < 1024) {
312
- const dim3 block_dims(WARP_SIZE, 1, 1);
313
- rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
360
+ const dim3 block_dims(256, 1, 1);
361
+ rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
314
362
  } else {
315
363
  const dim3 block_dims(1024, 1, 1);
316
- rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
364
+ rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
365
+ }
366
+ }
367
+
368
+ static void rms_norm_mul_f32_cuda(const float * x,
369
+ const float * mul,
370
+ const float * add,
371
+ float * dst,
372
+ const int ncols,
373
+ const int nrows,
374
+ const int nchannels,
375
+ const int nsamples,
376
+ const int64_t stride_row,
377
+ const int64_t stride_channel,
378
+ const int64_t stride_sample,
379
+ const int64_t mul_stride_row,
380
+ const int64_t mul_stride_channel,
381
+ const int64_t mul_stride_sample,
382
+ const uint32_t mul_ncols,
383
+ const uint32_t mul_nrows,
384
+ const uint32_t mul_nchannels,
385
+ const uint32_t mul_nsamples,
386
+ const int64_t add_stride_row,
387
+ const int64_t add_stride_channel,
388
+ const int64_t add_stride_sample,
389
+ const uint32_t add_ncols,
390
+ const uint32_t add_nrows,
391
+ const uint32_t add_nchannels,
392
+ const uint32_t add_nsamples,
393
+ const float eps,
394
+ cudaStream_t stream) {
395
+ const dim3 blocks_num(nrows, nchannels, nsamples);
396
+ if (mul == nullptr) {
397
+ rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
398
+ return;
399
+ }
400
+ if (add == nullptr) {
401
+ const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
402
+ const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
403
+ const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
404
+ const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
405
+ if (ncols < 1024) {
406
+ const dim3 block_dims(256, 1, 1);
407
+ rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
408
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
409
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
410
+ } else {
411
+ const dim3 block_dims(1024, 1, 1);
412
+ rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
413
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
414
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
415
+ }
416
+ } else {
417
+ const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
418
+ const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
419
+ const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
420
+ const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
421
+
422
+ const uint3 add_ncols_packed = init_fastdiv_values(add_ncols);
423
+ const uint3 add_nrows_packed = init_fastdiv_values(add_nrows);
424
+ const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
425
+ const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
426
+ if (ncols < 1024) {
427
+ const dim3 block_dims(256, 1, 1);
428
+ rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
429
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
430
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
431
+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
432
+ add_nchannels_packed, add_nsamples_packed);
433
+ } else {
434
+ const dim3 block_dims(1024, 1, 1);
435
+ rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
436
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
437
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
438
+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
439
+ add_nchannels_packed, add_nsamples_packed);
440
+ }
317
441
  }
318
442
  }
319
443
 
@@ -407,6 +531,154 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
407
531
  rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
408
532
  }
409
533
 
534
+ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
535
+ const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
536
+ float eps = 0.0f;
537
+
538
+ memcpy(&eps, dst->op_params, sizeof(float));
539
+
540
+ const float * src0_d = (const float *) rms_norm_src->data;
541
+ const float * mul_d = nullptr;
542
+ const ggml_tensor * mul_src = nullptr;
543
+
544
+ if (mul_tensor->src[0] == dst) {
545
+ mul_d = (float *) mul_tensor->src[1]->data;
546
+ mul_src = mul_tensor->src[1];
547
+ } else if(mul_tensor->src[1] == dst) {
548
+ mul_d = (float *) mul_tensor->src[0]->data;
549
+ mul_src = mul_tensor->src[0];
550
+ } else {
551
+ GGML_ASSERT(false);
552
+ }
553
+
554
+ float * dst_d = (float *) mul_tensor->data;
555
+ cudaStream_t stream = ctx.stream();
556
+
557
+ GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
558
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
559
+ GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
560
+ GGML_ASSERT(eps >= 0.0f);
561
+
562
+ const int64_t ne00 = rms_norm_src->ne[0];
563
+ const int64_t ne01 = rms_norm_src->ne[1];
564
+ const int64_t ne02 = rms_norm_src->ne[2];
565
+ const int64_t ne03 = rms_norm_src->ne[3];
566
+
567
+ const size_t ts0 = ggml_type_size(rms_norm_src->type);
568
+ GGML_ASSERT(rms_norm_src->nb[0] == ts0);
569
+ const int64_t s01 = rms_norm_src->nb[1] / ts0;
570
+ const int64_t s02 = rms_norm_src->nb[2] / ts0;
571
+ const int64_t s03 = rms_norm_src->nb[3] / ts0;
572
+
573
+ const size_t ts_mul = ggml_type_size(mul_src->type);
574
+ GGML_ASSERT(mul_src->nb[0] == ts_mul);
575
+ const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
576
+ const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
577
+ const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
578
+
579
+ const int mul_ncols = mul_src->ne[0];
580
+ const int mul_nrows = mul_src->ne[1];
581
+ const int mul_nchannels = mul_src->ne[2];
582
+ const int mul_nsamples = mul_src->ne[3];
583
+
584
+ rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
585
+ ne00, ne01, ne02, ne03,
586
+ /*s00*/ s01, s02, s03,
587
+ /*mul_s00*/ mul_s01, mul_s02, mul_s03,
588
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
589
+ /*add_s00*/ 0, 0, 0,
590
+ 0, 0, 0, 0,
591
+ eps, stream);
592
+ }
593
+
594
+ void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
595
+ ggml_tensor * dst,
596
+ ggml_tensor * mul_tensor,
597
+ ggml_tensor * add_tensor) {
598
+ const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
599
+ float eps = 0.0f;
600
+
601
+ memcpy(&eps, dst->op_params, sizeof(float));
602
+
603
+ const float * src0_d = (const float *) rms_norm_src->data;
604
+ const float * mul_d = nullptr;
605
+ const ggml_tensor * mul_src = nullptr;
606
+
607
+ if (mul_tensor->src[0] == dst) {
608
+ mul_d = (float *) mul_tensor->src[1]->data;
609
+ mul_src = mul_tensor->src[1];
610
+ } else if (mul_tensor->src[1] == dst) {
611
+ mul_d = (float *) mul_tensor->src[0]->data;
612
+ mul_src = mul_tensor->src[0];
613
+ } else {
614
+ GGML_ASSERT(false);
615
+ }
616
+
617
+ const float * add_d = nullptr;
618
+ const ggml_tensor * add_src = nullptr;
619
+
620
+ if (add_tensor->src[0] == mul_tensor) {
621
+ add_d = (float *) add_tensor->src[1]->data;
622
+ add_src = add_tensor->src[1];
623
+ } else if (add_tensor->src[1] == mul_tensor) {
624
+ add_d = (float *) add_tensor->src[0]->data;
625
+ add_src = add_tensor->src[0];
626
+ } else {
627
+ GGML_ASSERT(false);
628
+ }
629
+
630
+ float * dst_d = (float *) add_tensor->data;
631
+ cudaStream_t stream = ctx.stream();
632
+
633
+ GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
634
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
635
+ GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
636
+ GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
637
+ GGML_ASSERT(eps >= 0.0f);
638
+
639
+ const int64_t ne00 = rms_norm_src->ne[0];
640
+ const int64_t ne01 = rms_norm_src->ne[1];
641
+ const int64_t ne02 = rms_norm_src->ne[2];
642
+ const int64_t ne03 = rms_norm_src->ne[3];
643
+
644
+ const size_t ts0 = ggml_type_size(rms_norm_src->type);
645
+ GGML_ASSERT(rms_norm_src->nb[0] == ts0);
646
+ const int64_t s01 = rms_norm_src->nb[1] / ts0;
647
+ const int64_t s02 = rms_norm_src->nb[2] / ts0;
648
+ const int64_t s03 = rms_norm_src->nb[3] / ts0;
649
+
650
+ const size_t ts_mul = ggml_type_size(mul_src->type);
651
+ GGML_ASSERT(mul_src->nb[0] == ts_mul);
652
+ const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
653
+ const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
654
+ const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
655
+
656
+ const int mul_ncols = mul_src->ne[0];
657
+ const int mul_nrows = mul_src->ne[1];
658
+ const int mul_nchannels = mul_src->ne[2];
659
+ const int mul_nsamples = mul_src->ne[3];
660
+
661
+ const size_t ts_add = ggml_type_size(add_src->type);
662
+ GGML_ASSERT(add_src->nb[0] == ts_add);
663
+ const int64_t add_s01 = add_src->nb[1] / ts_add;
664
+ const int64_t add_s02 = add_src->nb[2] / ts_add;
665
+ const int64_t add_s03 = add_src->nb[3] / ts_add;
666
+
667
+ const int add_ncols = add_src->ne[0];
668
+ const int add_nrows = add_src->ne[1];
669
+ const int add_nchannels = add_src->ne[2];
670
+ const int add_nsamples = add_src->ne[3];
671
+
672
+ rms_norm_mul_f32_cuda(src0_d, mul_d,add_d,dst_d,
673
+ ne00,ne01, ne02, ne03,
674
+ /*s00*/ s01, s02, s03,
675
+ /*mul_s00*/ mul_s01, mul_s02, mul_s03,
676
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
677
+ /*add_s00*/ add_s01, add_s02, add_s03,
678
+ add_ncols, add_nrows, add_nchannels, add_nsamples,
679
+ eps, stream);
680
+ }
681
+
410
682
  void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
411
683
  const ggml_tensor * grad = dst->src[0]; // gradients
412
684
  const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
@@ -6,6 +6,13 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
6
6
 
7
7
  void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
8
8
 
9
+ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
10
+
11
+ void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
12
+ ggml_tensor * dst,
13
+ ggml_tensor * mul_tensor,
14
+ ggml_tensor * add_tensor);
15
+
9
16
  void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
10
17
 
11
18
  void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -0,0 +1,49 @@
1
+ #include "ggml-impl.h"
2
+ #include "opt-step-sgd.cuh"
3
+
4
+ #include <cstdint>
5
+
6
+ static __global__ void opt_step_sgd_f32(
7
+ float * __restrict__ x, const float * __restrict__ g,
8
+ const float * __restrict__ pars, const int64_t k) {
9
+
10
+ const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
11
+
12
+ if (i >= k) {
13
+ return;
14
+ }
15
+ x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i];
16
+ }
17
+
18
+ static void opt_step_sgd_f32_cuda(
19
+ float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {
20
+
21
+ const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
22
+ const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
23
+ opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k);
24
+ }
25
+
26
+ void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
27
+ const ggml_tensor * src0 = dst->src[0];
28
+ const ggml_tensor * src0_grad = dst->src[1];
29
+ const ggml_tensor * params = dst->src[2];
30
+
31
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
32
+ GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
33
+ GGML_ASSERT(params->type == GGML_TYPE_F32);
34
+ GGML_ASSERT(ggml_is_contiguous(src0));
35
+ GGML_ASSERT(ggml_is_contiguous(src0_grad));
36
+ GGML_ASSERT(ggml_is_contiguous(params));
37
+ GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
38
+ GGML_ASSERT(ggml_nelements(params) == 2);
39
+
40
+ float * src0_d = (float *) src0->data;
41
+ const float * src0_grad_d = (const float *) src0_grad->data;
42
+ const float * params_d = (const float *) params->data;
43
+
44
+ cudaStream_t stream = ctx.stream();
45
+
46
+ const int64_t ne = ggml_nelements(src0);
47
+
48
+ opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream);
49
+ }
@@ -0,0 +1,5 @@
1
+ #include "common.cuh"
2
+
3
+ #define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256
4
+
5
+ void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -1,36 +1,50 @@
1
1
  #include "pad.cuh"
2
2
 
3
- static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
4
- // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
5
- // blockIdx.y: idx of ne1
6
- // blockIDx.x: idx of ne0 / BLOCK_SIZE
7
- int nidx = threadIdx.x + blockIdx.x * blockDim.x;
8
- if (nidx >= ne0) {
3
+ static __global__ void pad_f32(const float * src, float * dst,
4
+ const int lp0, const int rp0, const int lp1, const int rp1,
5
+ const int lp2, const int rp2, const int lp3, const int rp3,
6
+ const int ne0, const int ne1, const int ne2, const int ne3) {
7
+ // blockIdx.z: i3*ne2+i2
8
+ // blockIdx.y: i1
9
+ // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE
10
+ // gridDim.y: ne1
11
+ int i0 = threadIdx.x + blockIdx.x * blockDim.x;
12
+ int i1 = blockIdx.y;
13
+ int i2 = blockIdx.z % ne2;
14
+ int i3 = blockIdx.z / ne2;
15
+ if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
9
16
  return;
10
17
  }
11
18
 
12
19
  // operation
13
- int offset_dst =
14
- nidx +
15
- blockIdx.y * ne0 +
16
- blockIdx.z * ne0 * gridDim.y;
17
- if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) {
18
- int offset_src =
19
- nidx +
20
- blockIdx.y * ne00 +
21
- blockIdx.z * ne00 * ne01;
22
- dst[offset_dst] = x[offset_src];
20
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
21
+ if ((i0 >= lp0 && i0 < ne0 - rp0) &&
22
+ (i1 >= lp1 && i1 < ne1 - rp1) &&
23
+ (i2 >= lp2 && i2 < ne2 - rp2) &&
24
+ (i3 >= lp3 && i3 < ne3 - rp3)) {
25
+ const int64_t i00 = i0 - lp0;
26
+ const int64_t i01 = i1 - lp1;
27
+ const int64_t i02 = i2 - lp2;
28
+ const int64_t i03 = i3 - lp3;
29
+ const int64_t ne02 = ne2 - lp2 - rp2;
30
+ const int64_t ne01 = ne1 - lp1 - rp1;
31
+ const int64_t ne00 = ne0 - lp0 - rp0;
32
+
33
+ const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
34
+
35
+ dst[dst_idx] = src[src_idx];
23
36
  } else {
24
- dst[offset_dst] = 0.0f;
37
+ dst[dst_idx] = 0.0f;
25
38
  }
26
39
  }
27
40
 
28
- static void pad_f32_cuda(const float * x, float * dst,
29
- const int ne00, const int ne01, const int ne02, const int ne03,
41
+ static void pad_f32_cuda(const float * src, float * dst,
42
+ const int lp0, const int rp0, const int lp1, const int rp1,
43
+ const int lp2, const int rp2, const int lp3, const int rp3,
30
44
  const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
31
45
  int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
32
46
  dim3 gridDim(num_blocks, ne1, ne2*ne3);
33
- pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
47
+ pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
34
48
  }
35
49
 
36
50
  void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -41,9 +55,18 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
41
55
 
42
56
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
43
57
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
44
- GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
58
+ GGML_ASSERT(ggml_is_contiguous(src0));
59
+
60
+ const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
61
+ const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
62
+ const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
63
+ const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
64
+ const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
65
+ const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
66
+ const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
67
+ const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
45
68
 
46
69
  pad_f32_cuda(src0_d, dst_d,
47
- src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
48
- dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
70
+ lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
71
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
49
72
  }
@@ -0,0 +1,91 @@
1
+ #include "pad_reflect_1d.cuh"
2
+
3
+ static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
4
+ pad_reflect_1d_kernel_f32(
5
+ const void * __restrict__ src0,
6
+ void * __restrict__ dst,
7
+ const int64_t ne0,
8
+ const int64_t ne00,
9
+ const uint3 ne01,
10
+ const int64_t ne02,
11
+ const int64_t ne03,
12
+ const int64_t nb00,
13
+ const int64_t nb01,
14
+ const int64_t nb02,
15
+ const int64_t nb03,
16
+ const int64_t nb0,
17
+ const int64_t nb1,
18
+ const int64_t nb2,
19
+ const int64_t nb3,
20
+ const int p0,
21
+ const int p1) {
22
+ const int64_t i3 = blockIdx.z;
23
+ const int64_t i2 = blockIdx.y;
24
+
25
+ const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01);
26
+ const int64_t tile1 = div_mod_packed.y; // i1
27
+ const int64_t tile0 = div_mod_packed.x; // nth i0 tile
28
+ const int64_t i1 = tile1;
29
+ const int64_t i0 = threadIdx.x + tile0 * blockDim.x;
30
+
31
+ // ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh)
32
+ if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) {
33
+ return;
34
+ }
35
+
36
+ const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
37
+ char * dst_ptr = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
38
+
39
+ const int64_t rel_i0 = i0 - p0; // relative i0 in src0
40
+ int64_t src_idx;
41
+
42
+ if (rel_i0 < 0) {
43
+ // Left padding - reflect
44
+ src_idx = -rel_i0;
45
+ } else if (rel_i0 < ne00) {
46
+ // Middle - copy
47
+ src_idx = rel_i0;
48
+ } else {
49
+ // Right padding - reflect
50
+ src_idx = 2 * ne00 - 2 - rel_i0;
51
+ }
52
+ const float value = *(const float *) (src0_ptr + src_idx * nb00);
53
+ *(float *) (dst_ptr + i0 * nb0) = value;
54
+
55
+ GGML_UNUSED(p1);
56
+ }
57
+
58
+ void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
59
+ const ggml_tensor * src0 = dst->src[0];
60
+ cudaStream_t stream = ctx.stream();
61
+
62
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
63
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
64
+
65
+ const int32_t * opts = (const int32_t *) dst->op_params;
66
+ const int p0 = opts[0];
67
+ const int p1 = opts[1];
68
+
69
+ const int64_t ne00 = src0->ne[0];
70
+ const int64_t ne01 = src0->ne[1];
71
+ const uint3 ne01_packed = init_fastdiv_values(ne01);
72
+ const int64_t ne02 = src0->ne[2];
73
+ const int64_t ne03 = src0->ne[3];
74
+
75
+ const int64_t ne0 = dst->ne[0];
76
+
77
+ // sanity: padded length matches
78
+ GGML_ASSERT(ne0 == ne00 + p0 + p1);
79
+
80
+ constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x)
81
+ const int64_t tiles0 = (ne0 + bx - 1) / bx; // number of tiles along i0
82
+ // grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
83
+ // grid.y covers i2: [ne02]
84
+ // grid.z covers i3: [ne03]
85
+ const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03);
86
+ const dim3 block_dims((unsigned) bx, 1, 1);
87
+
88
+ pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
89
+ src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
90
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1);
91
+ }
@@ -0,0 +1,5 @@
1
+ #include "common.cuh"
2
+
3
+ #define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256
4
+
5
+ void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);