whispercpp 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (586)
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -8,6 +8,7 @@
8
8
  #include "vec.h"
9
9
 
10
10
  #include <float.h>
11
+ #include <algorithm>
11
12
 
12
13
  // ggml_compute_forward_dup
13
14
 
@@ -40,13 +41,15 @@ static void ggml_compute_forward_dup_same_cont(
40
41
  }
41
42
  }
42
43
 
43
- static void ggml_compute_forward_dup_f16(
44
+ template<typename src_t, typename dst_t>
45
+ static void ggml_compute_forward_dup_flt(
44
46
  const ggml_compute_params * params,
45
47
  ggml_tensor * dst) {
46
48
 
47
49
  const ggml_tensor * src0 = dst->src[0];
48
50
 
49
51
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
52
+ GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));
50
53
 
51
54
  GGML_TENSOR_UNARY_OP_LOCALS
52
55
 
@@ -61,6 +64,7 @@ static void ggml_compute_forward_dup_f16(
61
64
  const int ir0 = dr * ith;
62
65
  const int ir1 = MIN(ir0 + dr, nr);
63
66
 
67
+ // case: type & row size equal
64
68
  if (src0->type == dst->type &&
65
69
  ne00 == ne0 &&
66
70
  nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
@@ -79,11 +83,11 @@ static void ggml_compute_forward_dup_f16(
79
83
  return;
80
84
  }
81
85
 
82
- // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
83
-
86
+ // case: dst tensor is contiguous
84
87
  if (ggml_is_contiguous(dst)) {
85
- if (nb00 == sizeof(ggml_fp16_t)) {
86
- if (dst->type == GGML_TYPE_F16) {
88
+ if (nb00 == sizeof(src_t)) {
89
+ if constexpr (std::is_same_v<dst_t, src_t>) {
90
+ // same type
87
91
  size_t id = 0;
88
92
  const size_t rs = ne00 * nb00;
89
93
  char * dst_ptr = (char *) dst->data;
@@ -99,91 +103,46 @@ static void ggml_compute_forward_dup_f16(
99
103
  id += rs * (ne01 - ir1);
100
104
  }
101
105
  }
102
- } else if (dst->type == GGML_TYPE_F32) {
106
+ } else {
107
+ // casting between non-quantized types
103
108
  size_t id = 0;
104
- float * dst_ptr = (float *) dst->data;
109
+ dst_t * dst_ptr = (dst_t *) dst->data;
105
110
 
106
111
  for (int i03 = 0; i03 < ne03; i03++) {
107
112
  for (int i02 = 0; i02 < ne02; i02++) {
108
113
  id += ne00 * ir0;
109
114
  for (int i01 = ir0; i01 < ir1; i01++) {
110
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
115
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
111
116
  for (int i00 = 0; i00 < ne00; i00++) {
112
- dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
117
+ float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
118
+ dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
113
119
  id++;
114
120
  }
115
121
  }
116
122
  id += ne00 * (ne01 - ir1);
117
123
  }
118
124
  }
119
- } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
120
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
121
- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
122
-
123
- size_t id = 0;
124
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
125
- char * dst_ptr = (char *) dst->data;
126
-
127
- for (int i03 = 0; i03 < ne03; i03++) {
128
- for (int i02 = 0; i02 < ne02; i02++) {
129
- id += rs * ir0;
130
- for (int i01 = ir0; i01 < ir1; i01++) {
131
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
132
-
133
- for (int i00 = 0; i00 < ne00; i00++) {
134
- src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
135
- }
136
-
137
- quantize_row_q(src0_f32, dst_ptr + id, ne00);
138
- id += rs;
139
- }
140
- id += rs * (ne01 - ir1);
141
- }
142
- }
143
- } else {
144
- GGML_ABORT("fatal error"); // TODO: implement
145
125
  }
146
126
  } else {
147
127
  //printf("%s: this is not optimal - fix me\n", __func__);
148
128
 
149
- if (dst->type == GGML_TYPE_F32) {
150
- size_t id = 0;
151
- float * dst_ptr = (float *) dst->data;
152
-
153
- for (int i03 = 0; i03 < ne03; i03++) {
154
- for (int i02 = 0; i02 < ne02; i02++) {
155
- id += ne00 * ir0;
156
- for (int i01 = ir0; i01 < ir1; i01++) {
157
- for (int i00 = 0; i00 < ne00; i00++) {
158
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
159
-
160
- dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr);
161
- id++;
162
- }
163
- }
164
- id += ne00 * (ne01 - ir1);
165
- }
166
- }
167
- } else if (dst->type == GGML_TYPE_F16) {
168
- size_t id = 0;
169
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
129
+ size_t id = 0;
130
+ dst_t * dst_ptr = (dst_t *) dst->data;
170
131
 
171
- for (int i03 = 0; i03 < ne03; i03++) {
172
- for (int i02 = 0; i02 < ne02; i02++) {
173
- id += ne00 * ir0;
174
- for (int i01 = ir0; i01 < ir1; i01++) {
175
- for (int i00 = 0; i00 < ne00; i00++) {
176
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
132
+ for (int i03 = 0; i03 < ne03; i03++) {
133
+ for (int i02 = 0; i02 < ne02; i02++) {
134
+ id += ne00 * ir0;
135
+ for (int i01 = ir0; i01 < ir1; i01++) {
136
+ for (int i00 = 0; i00 < ne00; i00++) {
137
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
177
138
 
178
- dst_ptr[id] = *src0_ptr;
179
- id++;
180
- }
139
+ float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
140
+ dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
141
+ id++;
181
142
  }
182
- id += ne00 * (ne01 - ir1);
183
143
  }
144
+ id += ne00 * (ne01 - ir1);
184
145
  }
185
- } else {
186
- GGML_ABORT("fatal error"); // TODO: implement
187
146
  }
188
147
  }
189
148
  return;
@@ -195,7 +154,7 @@ static void ggml_compute_forward_dup_f16(
195
154
  int64_t i12 = 0;
196
155
  int64_t i13 = 0;
197
156
 
198
- if (dst->type == GGML_TYPE_F16) {
157
+ if constexpr (std::is_same_v<dst_t, src_t>) {
199
158
  for (int64_t i03 = 0; i03 < ne03; i03++) {
200
159
  for (int64_t i02 = 0; i02 < ne02; i02++) {
201
160
  i10 += ne00 * ir0;
@@ -216,7 +175,7 @@ static void ggml_compute_forward_dup_f16(
216
175
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
217
176
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
218
177
 
219
- memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
178
+ memcpy(dst_ptr, src0_ptr, sizeof(dst_t));
220
179
 
221
180
  if (++i10 == ne00) {
222
181
  i10 = 0;
@@ -247,7 +206,8 @@ static void ggml_compute_forward_dup_f16(
247
206
  }
248
207
  }
249
208
  }
250
- } else if (dst->type == GGML_TYPE_F32) {
209
+
210
+ } else {
251
211
  for (int64_t i03 = 0; i03 < ne03; i03++) {
252
212
  for (int64_t i02 = 0; i02 < ne02; i02++) {
253
213
  i10 += ne00 * ir0;
@@ -268,7 +228,8 @@ static void ggml_compute_forward_dup_f16(
268
228
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
269
229
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
270
230
 
271
- *(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
231
+ float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
232
+ *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);
272
233
 
273
234
  if (++i10 == ne0) {
274
235
  i10 = 0;
@@ -299,18 +260,19 @@ static void ggml_compute_forward_dup_f16(
299
260
  }
300
261
  }
301
262
  }
302
- } else {
303
- GGML_ABORT("fatal error"); // TODO: implement
304
263
  }
305
264
  }
306
265
 
307
- static void ggml_compute_forward_dup_bf16(
266
+
267
+ template<typename src_t>
268
+ static void ggml_compute_forward_dup_to_q(
308
269
  const ggml_compute_params * params,
309
270
  ggml_tensor * dst) {
310
271
 
311
272
  const ggml_tensor * src0 = dst->src[0];
312
273
 
313
274
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
275
+ GGML_ASSERT(!ggml_is_quantized(src0->type));
314
276
 
315
277
  GGML_TENSOR_UNARY_OP_LOCALS
316
278
 
@@ -325,11 +287,73 @@ static void ggml_compute_forward_dup_bf16(
325
287
  const int ir0 = dr * ith;
326
288
  const int ir1 = MIN(ir0 + dr, nr);
327
289
 
290
+ if (ggml_is_contiguous(dst) &&
291
+ nb00 == sizeof(src_t) &&
292
+ ggml_get_type_traits_cpu(dst->type)->from_float) {
293
+ // casting non-quantized types --> intermediate f32 --> quantized
294
+ ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
295
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
296
+
297
+ size_t id = 0;
298
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
299
+ char * dst_ptr = (char *) dst->data;
300
+
301
+ for (int i03 = 0; i03 < ne03; i03++) {
302
+ for (int i02 = 0; i02 < ne02; i02++) {
303
+ id += rs * ir0;
304
+ for (int i01 = ir0; i01 < ir1; i01++) {
305
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
306
+
307
+ for (int i00 = 0; i00 < ne00; i00++) {
308
+ src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
309
+ }
310
+
311
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
312
+ id += rs;
313
+ }
314
+ id += rs * (ne01 - ir1);
315
+ }
316
+ }
317
+ } else {
318
+ // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
319
+ GGML_ABORT("not implemented");
320
+ }
321
+ }
+
+ // A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
+ static void ggml_compute_forward_dup_bytes(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+ GGML_ASSERT(src0->type == dst->type);
+
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
+ ggml_compute_forward_dup_same_cont(params, dst);
+ return;
+ }
+
+ const size_t type_size = ggml_type_size(src0->type);
+
+ const int ith = params->ith; // thread index
+ const int nth = params->nth; // number of threads
+
+ // parallelize by rows
+ const int nr = ne01;
+ // number of rows per thread
+ const int dr = (nr + nth - 1) / nth;
+ // row range for this thread
+ const int ir0 = dr * ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
 if (src0->type == dst->type &&
- ne00 == ne0 &&
- nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+ ggml_are_same_shape(src0, dst) &&
+ nb00 == type_size && nb0 == type_size) {
 // copy by rows
- const size_t rs = ne00*nb00;
+ const size_t rs = ggml_row_size(src0->type, ne00);
 for (int64_t i03 = 0; i03 < ne03; i03++) {
 for (int64_t i02 = 0; i02 < ne02; i02++) {
 for (int64_t i01 = ir0; i01 < ir1; i01++) {
@@ -343,765 +367,110 @@ static void ggml_compute_forward_dup_bf16(
 return;
 }
 
- // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
 if (ggml_is_contiguous(dst)) {
- if (nb00 == sizeof(ggml_bf16_t)) {
- if (dst->type == GGML_TYPE_BF16) {
- size_t id = 0;
- const size_t rs = ne00 * nb00;
- char * dst_ptr = (char *) dst->data;
+ size_t id = 0;
+ char * dst_ptr = (char *) dst->data;
+ const size_t rs = ne00 * type_size;
 
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
- memcpy(dst_ptr + id, src0_ptr, rs);
- id += rs;
- }
- id += rs * (ne01 - ir1);
+ if (nb00 == type_size) {
+ // src0 is contiguous on first dimension, copy by rows
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ memcpy(dst_ptr + id, src0_ptr, rs);
+ id += rs;
 }
+ id += rs * (ne01 - ir1);
 }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+ }
+ } else {
+ //printf("%s: this is not optimal - fix me\n", __func__);
 
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+ memcpy(dst_ptr + id, src0_ptr, type_size);
 
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
- id++;
- }
+ id += type_size;
 }
- id += ne00 * (ne01 - ir1);
 }
+ id += rs * (ne01 - ir1);
 }
- } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+ }
+ }
 
- size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
- char * dst_ptr = (char *) dst->data;
+ return;
+ }
 
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ // dst counters
+ int64_t k10 = 0;
+ int64_t i11 = 0;
+ int64_t i12 = 0;
+ int64_t i13 = 0;
 
- for (int i00 = 0; i00 < ne00; i00++) {
- src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
- }
+ // number of blocks in a row
+ const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
+ const int64_t nk0 = ne0 / ggml_blck_size(dst->type);
 
- quantize_row_q(src0_f32, dst_ptr + id, ne00);
- id += rs;
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ k10 += nk00 * ir0;
+ while (k10 >= nk0) {
+ k10 -= nk0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
 }
- id += rs * (ne01 - ir1);
 }
 }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
 }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_BF16) {
- size_t id = 0;
- ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t k00 = 0; k00 < nk00; k00++) {
+ const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ memcpy(dst_ptr, src0_ptr, type_size);
 
- dst_ptr[id] = *src0_ptr;
- id++;
+ if (++k10 == nk0) {
+ k10 = 0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
 }
 }
- id += ne00 * (ne01 - ir1);
 }
 }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
- id++;
- }
+ }
+ k10 += nk00 * (ne01 - ir1);
+ while (k10 >= nk0) {
+ k10 -= nk0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
 }
- id += ne00 * (ne01 - ir1);
 }
 }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
 }
 }
- return;
 }
+ }
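
// (k10, i11, i12, i13) form a little-endian odometer over dst: k10 counts
// blocks within a row (nk0 = ne0 / blck_size), so block-quantized tensors are
// copied a whole block at a time, and each wrap carries into the next dim.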
 
- // dst counters
- int64_t i10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
-
- if (dst->type == GGML_TYPE_BF16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
-
- if (++i10 == ne00) {
- i10 = 0;
- if (++i11 == ne01) {
- i11 = 0;
- if (++i12 == ne02) {
- i12 = 0;
- if (++i13 == ne03) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_F32) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- }
-
- static void ggml_compute_forward_dup_f32(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
-
- const ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- const int ith = params->ith; // thread index
- const int nth = params->nth; // number of threads
-
- // parallelize by rows
- const int nr = ne01;
- // number of rows per thread
- const int dr = (nr + nth - 1) / nth;
- // row range for this thread
- const int ir0 = dr * ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (src0->type == dst->type &&
- ne00 == ne0 &&
- nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
- // copy by rows
- const size_t rs = ne00*nb00;
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- memcpy(
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
- rs);
- }
- }
- }
- return;
- }
-
- if (ggml_is_contiguous(dst)) {
- // TODO: simplify
- if (nb00 == sizeof(float)) {
- if (ggml_get_type_traits_cpu(dst->type)->from_float) {
- ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
-
- size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- from_float(src0_ptr, dst_ptr + id, ne00);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = *src0_ptr;
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_BF16) {
- size_t id = 0;
- ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- }
-
- return;
- }
-
- // dst counters
-
- int64_t i10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
-
- if (dst->type == GGML_TYPE_F32) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- memcpy(dst_ptr, src0_ptr, sizeof(float));
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_BF16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- }
-
- // A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
- static void ggml_compute_forward_dup_bytes(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
- const ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
- GGML_ASSERT(src0->type == dst->type);
-
- GGML_TENSOR_UNARY_OP_LOCALS;
-
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
- ggml_compute_forward_dup_same_cont(params, dst);
- return;
- }
-
- const size_t type_size = ggml_type_size(src0->type);
-
- const int ith = params->ith; // thread index
- const int nth = params->nth; // number of threads
-
- // parallelize by rows
- const int nr = ne01;
- // number of rows per thread
- const int dr = (nr + nth - 1) / nth;
- // row range for this thread
- const int ir0 = dr * ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (src0->type == dst->type &&
- ggml_are_same_shape(src0, dst) &&
- nb00 == type_size && nb0 == type_size) {
- // copy by rows
- const size_t rs = ggml_row_size(src0->type, ne00);
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- memcpy(
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
- rs);
- }
- }
- }
- return;
- }
-
- if (ggml_is_contiguous(dst)) {
- size_t id = 0;
- char * dst_ptr = (char *) dst->data;
- const size_t rs = ne00 * type_size;
-
- if (nb00 == type_size) {
- // src0 is contiguous on first dimension, copy by rows
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
- memcpy(dst_ptr + id, src0_ptr, rs);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
- memcpy(dst_ptr + id, src0_ptr, type_size);
-
- id += type_size;
- }
- }
- id += rs * (ne01 - ir1);
- }
- }
- }
-
- return;
- }
-
- // dst counters
- int64_t k10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
-
- // number of blocks in a row
- const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
- const int64_t nk0 = ne0 / ggml_blck_size(dst->type);
-
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- k10 += nk00 * ir0;
- while (k10 >= nk0) {
- k10 -= nk0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t k00 = 0; k00 < nk00; k00++) {
- const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- memcpy(dst_ptr, src0_ptr, type_size);
-
- if (++k10 == nk0) {
- k10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- k10 += nk00 * (ne01 - ir1);
- while (k10 >= nk0) {
- k10 -= nk0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- }
-
- static void ggml_compute_forward_dup_q(
+ static void ggml_compute_forward_dup_from_q(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
 
@@ -1166,20 +535,35 @@ void ggml_compute_forward_dup(
 switch (src0->type) {
 case GGML_TYPE_F16:
 {
- ggml_compute_forward_dup_f16(params, dst);
+ /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_fp16_t, float >(params, dst);
+ else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
 } break;
 case GGML_TYPE_BF16:
 {
- ggml_compute_forward_dup_bf16(params, dst);
+ /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_bf16_t, float >(params, dst);
+ else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
 } break;
 case GGML_TYPE_F32:
 {
- ggml_compute_forward_dup_f32(params, dst);
+ /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<float, float >(params, dst);
+ else if (dst->type == GGML_TYPE_I32) ggml_compute_forward_dup_flt<float, int32_t >(params, dst);
+ else ggml_compute_forward_dup_to_q<float>(params, dst);
+ } break;
+ case GGML_TYPE_I32:
+ {
+ if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
+ else GGML_ABORT("not implemented");
 } break;
 default:
 {
 if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
- ggml_compute_forward_dup_q(params, dst);
+ ggml_compute_forward_dup_from_q(params, dst);
 break;
 }
 GGML_ABORT("fatal error");
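
// taken together, the dispatch above is: dup_flt<src_t, dst_t> for the
// float/int element casts, dup_to_q for float -> quantized, and dup_from_q
// for quantized -> f32; ggml_compute_forward_dup_bytes (which asserts
// src0->type == dst->type) covers the raw same-type copies.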
@@ -1252,20 +636,118 @@ static void ggml_compute_forward_add_q_f32(
 
 assert(ne00 % 32 == 0);
 
- // unquantize row from src0 to temp buffer
- dequantize_row_q(src0_row, wdata, ne00);
- // add src1
- ggml_vec_acc_f32(ne00, wdata, src1_row);
- // quantize row to dst
- if (quantize_row_q != NULL) {
- quantize_row_q(wdata, dst_row, ne00);
- } else {
- memcpy(dst_row, wdata, ne0*nb0);
- }
+ // unquantize row from src0 to temp buffer
+ dequantize_row_q(src0_row, wdata, ne00);
+ // add src1
+ ggml_vec_acc_f32(ne00, wdata, src1_row);
+ // quantize row to dst
+ if (quantize_row_q != NULL) {
+ quantize_row_q(wdata, dst_row, ne00);
+ } else {
+ memcpy(dst_row, wdata, ne0*nb0);
+ }
+ }
+ }
+
+ void ggml_compute_forward_add(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_BF16:
+ {
+ ggml_compute_forward_add_non_quantized(params, dst);
+ } break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_TQ1_0:
+ case GGML_TYPE_TQ2_0:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ2_S:
+ {
+ ggml_compute_forward_add_q_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
+ // ggml_compute_forward_add_id
+
+ static void ggml_compute_forward_add_id_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_TERNARY_OP_LOCALS
+
+ GGML_ASSERT( nb0 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ // src1 indices
+ const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+ GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+ ggml_vec_add_f32(ne0,
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+ (float *) ((char *) src1->data + i11*nb11));
 }
 }
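
// add_id semantics: for each src0 row (i1, i2, i3), src2 supplies an int32
// index i11 into the rows of src1, so
// dst[:, i1, i2, i3] = src0[:, i1, i2, i3] + src1[:, i11]
// (a gather over src1 rows followed by a vector add).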
 
- void ggml_compute_forward_add(
+ void ggml_compute_forward_add_id(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
 
@@ -1273,38 +755,12 @@ void ggml_compute_forward_add(
 
 switch (src0->type) {
 case GGML_TYPE_F32:
- case GGML_TYPE_F16:
- case GGML_TYPE_BF16:
- {
- ggml_compute_forward_add_non_quantized(params, dst);
- } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_TQ1_0:
- case GGML_TYPE_TQ2_0:
- case GGML_TYPE_IQ2_XXS:
- case GGML_TYPE_IQ2_XS:
- case GGML_TYPE_IQ3_XXS:
- case GGML_TYPE_IQ1_S:
- case GGML_TYPE_IQ1_M:
- case GGML_TYPE_IQ4_NL:
- case GGML_TYPE_IQ4_XS:
- case GGML_TYPE_IQ3_S:
- case GGML_TYPE_IQ2_S:
 {
- ggml_compute_forward_add_q_f32(params, dst);
+ ggml_compute_forward_add_id_f32(params, dst);
 } break;
 default:
 {
- GGML_ABORT("fatal error");
+ GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
 }
 }
 }
@@ -1660,6 +1116,7 @@ void ggml_compute_forward_add1(
 case GGML_TYPE_Q5_1:
 case GGML_TYPE_Q8_0:
 case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
 case GGML_TYPE_Q2_K:
 case GGML_TYPE_Q3_K:
 case GGML_TYPE_Q4_K:
@@ -1787,6 +1244,7 @@ void ggml_compute_forward_acc(
 case GGML_TYPE_Q5_1:
 case GGML_TYPE_Q8_0:
 case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
 case GGML_TYPE_Q2_K:
 case GGML_TYPE_Q3_K:
 case GGML_TYPE_Q4_K:
@@ -3009,50 +2467,304 @@ static void ggml_compute_forward_leaky_relu_f32(
 const int n = ggml_nrows(src0);
 const int nc = src0->ne[0];
 
- float negative_slope;
- memcpy(&negative_slope, dst->op_params, sizeof(float));
+ float negative_slope;
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_leaky_relu_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
+ }
+ }
+
+ static void ggml_compute_forward_leaky_relu_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ float negative_slope;
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+ assert(dst->nb[0] == sizeof(ggml_fp16_t));
+ assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_leaky_relu_f16(nc,
+ (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])),
+ (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
+ }
+ }
+
+ void ggml_compute_forward_leaky_relu(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_leaky_relu_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_leaky_relu_f16(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
+ // ggml_compute_forward_silu_back
+
+ static void ggml_compute_forward_silu_back_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * grad = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ assert(ggml_is_contiguous_1(grad));
+ assert(ggml_is_contiguous_1(src1));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src1, dst));
+ assert(ggml_are_same_shape(src1, grad));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1->ne[0];
+ const int nr = ggml_nrows(src1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_silu_backward_f32(nc,
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
+ (float *) ((char *) src1->data + i1*(src1->nb[1])),
+ (float *) ((char *) grad->data + i1*(grad->nb[1])));
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+ #endif
+ }
+ }
+
+ static void ggml_compute_forward_silu_back_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * grad = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ assert(ggml_is_contiguous_1(grad));
+ assert(ggml_is_contiguous_1(src1));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src1, dst));
+ assert(ggml_are_same_shape(src1, grad));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1->ne[0];
+ const int nr = ggml_nrows(src1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_silu_backward_f16(nc,
+ (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
+ (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
+ (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ const float v = GGML_CPU_FP16_TO_FP32(x);
+ GGML_UNUSED(v);
+ assert(!isnan(v));
+ assert(!isinf(v));
+ }
+ #endif
+ }
+ }
+
+ void ggml_compute_forward_silu_back(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_silu_back_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_silu_back_f16(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
+ // ggml_compute_forward_reglu
+
+ static void ggml_compute_forward_reglu_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
 
- assert(dst->nb[0] == sizeof(float));
- assert(src0->nb[0] == sizeof(float));
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
 
- for (int i = 0; i < n; i++) {
- ggml_vec_leaky_relu_f32(nc,
- (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
+ ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+ #endif
 }
 }
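
// all of the *_glu kernels below share this input convention: either src1 is
// a separate gate tensor, or src0 packs both halves in one row
// (nc = ne[0] / 2), with op_params[1] ("swapped") selecting which half is
// the gate.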
 
- static void ggml_compute_forward_leaky_relu_f16(
+ static void ggml_compute_forward_reglu_f16(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
 
 const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 
- if (params->ith != 0) {
- return;
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
 }
 
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
+ const int ith = params->ith;
+ const int nth = params->nth;
 
- const int n = ggml_nrows(src0);
- const int nc = src0->ne[0];
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
 
- float negative_slope;
- memcpy(&negative_slope, dst->op_params, sizeof(float));
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
 
- assert(dst->nb[0] == sizeof(ggml_fp16_t));
- assert(src0->nb[0] == sizeof(ggml_fp16_t));
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 
- for (int i = 0; i < n; i++) {
- ggml_vec_leaky_relu_f16(nc,
- (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])),
- (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ const float v = GGML_FP16_TO_FP32(x);
+ GGML_UNUSED(v);
+ assert(!isnan(v));
+ assert(!isinf(v));
+ }
+ #endif
 }
 }
 
- void ggml_compute_forward_leaky_relu(
+ static void ggml_compute_forward_reglu(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
 
@@ -3061,11 +2773,11 @@ void ggml_compute_forward_leaky_relu(
 switch (src0->type) {
 case GGML_TYPE_F32:
 {
- ggml_compute_forward_leaky_relu_f32(params, dst);
+ ggml_compute_forward_reglu_f32(params, dst);
 } break;
 case GGML_TYPE_F16:
 {
- ggml_compute_forward_leaky_relu_f16(params, dst);
+ ggml_compute_forward_reglu_f16(params, dst);
 } break;
 default:
 {
@@ -3074,26 +2786,37 @@ void ggml_compute_forward_leaky_relu(
 }
 }
 
- // ggml_compute_forward_silu_back
+ // ggml_compute_forward_geglu
 
- static void ggml_compute_forward_silu_back_f32(
+ static void ggml_compute_forward_geglu_f32(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
 
- const ggml_tensor * grad = dst->src[0];
+ const ggml_tensor * src0 = dst->src[0];
 const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 
- assert(ggml_is_contiguous_1(grad));
- assert(ggml_is_contiguous_1(src1));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src1, dst));
- assert(ggml_are_same_shape(src1, grad));
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
 
 const int ith = params->ith;
 const int nth = params->nth;
 
- const int nc = src1->ne[0];
- const int nr = ggml_nrows(src1);
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 
 // rows per thread
 const int dr = (nr + nth - 1)/nth;
@@ -3103,10 +2826,15 @@ static void ggml_compute_forward_silu_back_f32(
 const int ir1 = MIN(ir0 + dr, nr);
 
 for (int i1 = ir0; i1 < ir1; i1++) {
- ggml_vec_silu_backward_f32(nc,
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
- (float *) ((char *) src1->data + i1*(src1->nb[1])),
- (float *) ((char *) grad->data + i1*(grad->nb[1])));
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
 for (int k = 0; k < nc; k++) {
@@ -3119,24 +2847,35 @@ static void ggml_compute_forward_silu_back_f32(
 }
 }
 
- static void ggml_compute_forward_silu_back_f16(
+ static void ggml_compute_forward_geglu_f16(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
 
- const ggml_tensor * grad = dst->src[0];
+ const ggml_tensor * src0 = dst->src[0];
 const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 
- assert(ggml_is_contiguous_1(grad));
- assert(ggml_is_contiguous_1(src1));
- assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src1, dst));
- assert(ggml_are_same_shape(src1, grad));
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
 
 const int ith = params->ith;
 const int nth = params->nth;
 
- const int nc = src1->ne[0];
- const int nr = ggml_nrows(src1);
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 
 // rows per thread
 const int dr = (nr + nth - 1)/nth;
@@ -3146,24 +2885,29 @@ static void ggml_compute_forward_silu_back_f16(
 const int ir1 = MIN(ir0 + dr, nr);
 
 for (int i1 = ir0; i1 < ir1; i1++) {
- ggml_vec_silu_backward_f16(nc,
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
- (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
- (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
 
- #ifndef NDEBUG
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
 for (int k = 0; k < nc; k++) {
- const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
- const float v = GGML_CPU_FP16_TO_FP32(x);
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ const float v = GGML_FP16_TO_FP32(x);
 GGML_UNUSED(v);
 assert(!isnan(v));
 assert(!isinf(v));
 }
- #endif
+ #endif
 }
 }
 
- void ggml_compute_forward_silu_back(
+ static void ggml_compute_forward_geglu(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
 
@@ -3172,11 +2916,11 @@ void ggml_compute_forward_silu_back(
 switch (src0->type) {
 case GGML_TYPE_F32:
 {
- ggml_compute_forward_silu_back_f32(params, dst);
+ ggml_compute_forward_geglu_f32(params, dst);
 } break;
 case GGML_TYPE_F16:
 {
- ggml_compute_forward_silu_back_f16(params, dst);
+ ggml_compute_forward_geglu_f16(params, dst);
 } break;
 default:
 {
@@ -3185,9 +2929,9 @@ void ggml_compute_forward_silu_back(
 }
 }
 
- // ggml_compute_forward_reglu
+ // ggml_compute_forward_swiglu
 
- static void ggml_compute_forward_reglu_f32(
+ static void ggml_compute_forward_swiglu_f32(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
 
@@ -3233,7 +2977,7 @@ static void ggml_compute_forward_reglu_f32(
 src1_p += swapped ? 0 : nc;
 }
 
- ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+ ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
 for (int k = 0; k < nc; k++) {
@@ -3246,9 +2990,93 @@ static void ggml_compute_forward_reglu_f32(
 }
 }
 
- static void ggml_compute_forward_reglu_f16(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
+ static void ggml_compute_forward_swiglu_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ const float v = GGML_FP16_TO_FP32(x);
+ GGML_UNUSED(v);
+ assert(!isnan(v));
+ assert(!isinf(v));
+ }
+ #endif
+ }
+ }
+
+ static void ggml_compute_forward_swiglu(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_swiglu_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_swiglu_f16(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
+ // ggml_compute_forward_swiglu_oai
+
+ static void ggml_compute_forward_swiglu_oai_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
 
 const ggml_tensor * src0 = dst->src[0];
 const ggml_tensor * src1 = dst->src[1];
@@ -3275,6 +3103,8 @@ static void ggml_compute_forward_reglu_f16(
 GGML_ASSERT(ggml_nrows(dst) == nr);
 
 const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+ const float alpha = ggml_get_op_params_f32(dst, 2);
+ const float limit = ggml_get_op_params_f32(dst, 3);
 
 // rows per thread
 const int dr = (nr + nth - 1)/nth;
@@ -3284,29 +3114,34 @@ static void ggml_compute_forward_reglu_f16(
 const int ir1 = MIN(ir0 + dr, nr);
 
 for (int i1 = ir0; i1 < ir1; i1++) {
- ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
- ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
+ float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
 
 if (!src1) {
 src0_p += swapped ? nc : 0;
 src1_p += swapped ? 0 : nc;
 }
 
- ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+ for (int k = 0; k < nc; k++) {
+ const float x = std::min(src0_p[k], limit);
+ const float y = std::clamp(src1_p[k], -limit, limit);
+ const float out_glu = x / (1.f + expf(alpha * (-x)));
+ dst_p[k] = out_glu * (y + 1.f);
+ }
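
// i.e. dst = silu_alpha(min(x, limit)) * (clamp(y, -limit, limit) + 1),
// where silu_alpha(x) = x * sigmoid(alpha * x); both halves are bounded by
// `limit` before the gated product.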
 
 #ifndef NDEBUG
 for (int k = 0; k < nc; k++) {
- const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
- const float v = GGML_FP16_TO_FP32(x);
- GGML_UNUSED(v);
- assert(!isnan(v));
- assert(!isinf(v));
+ const float x = dst_p[k];
+ GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
 }
 #endif
 }
 }
 
- static void ggml_compute_forward_reglu(
+ static void ggml_compute_forward_swiglu_oai(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
 
@@ -3315,11 +3150,7 @@ static void ggml_compute_forward_reglu(
 switch (src0->type) {
 case GGML_TYPE_F32:
 {
- ggml_compute_forward_reglu_f32(params, dst);
- } break;
- case GGML_TYPE_F16:
- {
- ggml_compute_forward_reglu_f16(params, dst);
+ ggml_compute_forward_swiglu_oai_f32(params, dst);
 } break;
 default:
 {
@@ -3328,9 +3159,9 @@ static void ggml_compute_forward_reglu(
 }
 }
 
- // ggml_compute_forward_geglu
+ // ggml_compute_forward_geglu_erf
 
- static void ggml_compute_forward_geglu_f32(
+ static void ggml_compute_forward_geglu_erf_f32(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
 
@@ -3376,7 +3207,7 @@ static void ggml_compute_forward_geglu_f32(
 src1_p += swapped ? 0 : nc;
 }
 
- ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+ ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
 for (int k = 0; k < nc; k++) {
@@ -3389,7 +3220,7 @@ static void ggml_compute_forward_geglu_f32(
 }
 }
 
- static void ggml_compute_forward_geglu_f16(
+ static void ggml_compute_forward_geglu_erf_f16(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
 
@@ -3435,7 +3266,7 @@ static void ggml_compute_forward_geglu_f16(
 src1_p += swapped ? 0 : nc;
 }
 
- ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+ ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
 for (int k = 0; k < nc; k++) {
@@ -3449,7 +3280,7 @@ static void ggml_compute_forward_geglu_f16(
 }
 }
 
- static void ggml_compute_forward_geglu(
+ static void ggml_compute_forward_geglu_erf(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
 
@@ -3458,11 +3289,11 @@ static void ggml_compute_forward_geglu(
 switch (src0->type) {
 case GGML_TYPE_F32:
 {
- ggml_compute_forward_geglu_f32(params, dst);
+ ggml_compute_forward_geglu_erf_f32(params, dst);
 } break;
 case GGML_TYPE_F16:
 {
- ggml_compute_forward_geglu_f16(params, dst);
+ ggml_compute_forward_geglu_erf_f16(params, dst);
 } break;
 default:
 {
@@ -3471,9 +3302,9 @@ static void ggml_compute_forward_geglu(
3471
3302
  }
3472
3303
  }
3473
3304
 
3474
- // ggml_compute_forward_swiglu
3305
+ // ggml_compute_forward_geglu_quick
3475
3306
 
3476
- static void ggml_compute_forward_swiglu_f32(
3307
+ static void ggml_compute_forward_geglu_quick_f32(
3477
3308
  const ggml_compute_params * params,
3478
3309
  ggml_tensor * dst) {
3479
3310
 
@@ -3519,7 +3350,7 @@ static void ggml_compute_forward_swiglu_f32(
3519
3350
  src1_p += swapped ? 0 : nc;
3520
3351
  }
3521
3352
 
3522
- ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3353
+ ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3523
3354
 
3524
3355
  #ifndef NDEBUG
3525
3356
  for (int k = 0; k < nc; k++) {
@@ -3532,7 +3363,7 @@ static void ggml_compute_forward_swiglu_f32(
3532
3363
  }
3533
3364
  }
3534
3365
 
3535
- static void ggml_compute_forward_swiglu_f16(
3366
+ static void ggml_compute_forward_geglu_quick_f16(
3536
3367
  const ggml_compute_params * params,
3537
3368
  ggml_tensor * dst) {
3538
3369
 
@@ -3578,7 +3409,7 @@ static void ggml_compute_forward_swiglu_f16(
3578
3409
  src1_p += swapped ? 0 : nc;
3579
3410
  }
3580
3411
 
3581
- ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3412
+ ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3582
3413
 
3583
3414
  #ifndef NDEBUG
3584
3415
  for (int k = 0; k < nc; k++) {
@@ -3592,7 +3423,7 @@ static void ggml_compute_forward_swiglu_f16(
3592
3423
  }
3593
3424
  }
3594
3425
 
3595
- static void ggml_compute_forward_swiglu(
3426
+ static void ggml_compute_forward_geglu_quick(
3596
3427
  const ggml_compute_params * params,
3597
3428
  ggml_tensor * dst) {
3598
3429
 
@@ -3601,11 +3432,11 @@ static void ggml_compute_forward_swiglu(
3601
3432
  switch (src0->type) {
3602
3433
  case GGML_TYPE_F32:
3603
3434
  {
3604
- ggml_compute_forward_swiglu_f32(params, dst);
3435
+ ggml_compute_forward_geglu_quick_f32(params, dst);
3605
3436
  } break;
3606
3437
  case GGML_TYPE_F16:
3607
3438
  {
3608
- ggml_compute_forward_swiglu_f16(params, dst);
3439
+ ggml_compute_forward_geglu_quick_f16(params, dst);
3609
3440
  } break;
3610
3441
  default:
3611
3442
  {
@@ -3729,6 +3560,9 @@ static void ggml_compute_forward_rms_norm_f32(
3729
3560
 
3730
3561
  const float scale = 1.0f/sqrtf(mean + eps);
3731
3562
 
3563
+ // if you hit this, likely you got an inf somewhere earlier
3564
+ assert(scale > 0.0f);
3565
+
3732
3566
  ggml_vec_scale_f32(ne00, y, scale);
3733
3567
  }
3734
3568
  }
@@ -4310,6 +4144,7 @@ void ggml_compute_forward_out_prod(
4310
4144
  case GGML_TYPE_Q5_0:
4311
4145
  case GGML_TYPE_Q5_1:
4312
4146
  case GGML_TYPE_Q8_0:
4147
+ case GGML_TYPE_MXFP4:
4313
4148
  case GGML_TYPE_Q2_K:
4314
4149
  case GGML_TYPE_Q3_K:
4315
4150
  case GGML_TYPE_Q4_K:
@@ -4357,9 +4192,11 @@ static void ggml_compute_forward_scale_f32(
4357
4192
  GGML_ASSERT(ggml_is_contiguous(dst));
4358
4193
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
4359
4194
 
4360
- // scale factor
4361
- float v;
4362
- memcpy(&v, dst->op_params, sizeof(float));
4195
+ float s; // scale factor
4196
+ float b; // bias
4197
+
4198
+ memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
4199
+ memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
4363
4200
 
4364
4201
  const int ith = params->ith;
4365
4202
  const int nth = params->nth;
@@ -4378,12 +4215,22 @@ static void ggml_compute_forward_scale_f32(
4378
4215
 
4379
4216
  const size_t nb1 = dst->nb[1];
4380
4217
 
4381
- for (int i1 = ir0; i1 < ir1; i1++) {
4382
- if (dst->data != src0->data) {
4383
- // src0 is same shape as dst => same indices
4384
- memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
4218
+ if (b == 0.0f) {
4219
+ for (int i1 = ir0; i1 < ir1; i1++) {
4220
+ if (dst->data != src0->data) {
4221
+ // src0 is same shape as dst => same indices
4222
+ // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
4223
+ memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
4224
+ }
4225
+ ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
4226
+ }
4227
+ } else {
4228
+ for (int i1 = ir0; i1 < ir1; i1++) {
4229
+ ggml_vec_mad1_f32(nc,
4230
+ (float *) ((char *) dst->data + i1*nb1),
4231
+ (float *) ((char *) src0->data + i1*nb1),
4232
+ s, b);
4385
4233
  }
4386
- ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
4387
4234
  }
4388
4235
  }
4389
4236
 
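Note: ggml_compute_forward_scale_f32 now reads a second op param and computes dst = s*src0 + b in one pass via ggml_vec_mad1_f32, keeping the memcpy + ggml_vec_scale_f32 fast path when b == 0.0f. Reference semantics as a scalar sketch (illustrative, not the SIMD path):

    static void scale_bias_ref(float * dst, const float * src, int n, float s, float b) {
        for (int i = 0; i < n; ++i) {
            dst[i] = s * src[i] + b; // fused scale-and-bias per element
        }
    }
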
@@ -4572,6 +4419,7 @@ void ggml_compute_forward_set(
 case GGML_TYPE_Q5_1:
 case GGML_TYPE_Q8_0:
 case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
 case GGML_TYPE_Q2_K:
 case GGML_TYPE_Q3_K:
 case GGML_TYPE_Q4_K:
@@ -4833,6 +4681,7 @@ void ggml_compute_forward_get_rows(
 case GGML_TYPE_Q5_1:
 case GGML_TYPE_Q8_0:
 case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
 case GGML_TYPE_Q2_K:
 case GGML_TYPE_Q3_K:
 case GGML_TYPE_Q4_K:
@@ -4890,6 +4739,7 @@ void ggml_compute_forward_get_rows(
 //}
 }

+ template<typename idx_t>
 static void ggml_compute_forward_set_rows_f32(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
@@ -4928,7 +4778,7 @@ static void ggml_compute_forward_set_rows_f32(
 const int64_t i11 = i02%ne11;
 const int64_t i10 = i;

- const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+ const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

 GGML_ASSERT(i1 >= 0 && i1 < ne1);

@@ -4945,11 +4795,18 @@ void ggml_compute_forward_set_rows(
 ggml_tensor * dst) {

 const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];

 switch (src0->type) {
 case GGML_TYPE_F32:
 {
- ggml_compute_forward_set_rows_f32(params, dst);
+ if (src1->type == GGML_TYPE_I64) {
+ ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
+ } else if (src1->type == GGML_TYPE_I32) {
+ ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
+ } else {
+ GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
+ }
 } break;
 default:
 {
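Note: set_rows is now templated on the row-index type, so the index tensor src1 may hold I64 or I32 indices; the dispatcher above picks the instantiation from src1->type. The only type-dependent step, as a sketch (hypothetical helper, not in the source):

    template <typename idx_t> // int64_t or int32_t
    static int64_t row_index(const ggml_tensor * src1, size_t off) {
        return (int64_t) *(const idx_t *) ((const char *) src1->data + off);
    }
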
@@ -5222,6 +5079,7 @@ static void ggml_compute_forward_soft_max_f32(

 const ggml_tensor * src0 = dst->src[0];
 const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];

 assert(ggml_is_contiguous(dst));
 assert(ggml_are_same_shape(src0, dst));
@@ -5232,14 +5090,17 @@ static void ggml_compute_forward_soft_max_f32(
 memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
 memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));

- // TODO: handle transposed/permuted matrices
-
 const int ith = params->ith;
 const int nth = params->nth;

 GGML_TENSOR_UNARY_OP_LOCALS

- //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+ const int64_t nb11 = src1 ? src1->nb[1] : 1;
+ const int64_t nb12 = src1 ? src1->nb[2] : 1;
+ const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+ const int64_t ne12 = src1 ? src1->ne[2] : 1;
+ const int64_t ne13 = src1 ? src1->ne[3] : 1;

 // TODO: is this supposed to be ceil instead of floor?
 // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5249,68 +5110,78 @@ static void ggml_compute_forward_soft_max_f32(
 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

- const int nc = src0->ne[0];
- const int nr = ggml_nrows(src0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+ float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

 const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

- for (int i1 = ir0; i1 < ir1; i1++) {
- // ALiBi
- const uint32_t h = (i1/ne01)%ne02; // head
- const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
- float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
- float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+ // sinks
+ const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;

- // broadcast the mask across rows
- ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
- float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
- ggml_vec_cpy_f32 (nc, wp, sp);
- ggml_vec_scale_f32(nc, wp, scale);
- if (mp_f32) {
- if (use_f16) {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
- }
- } else {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*mp_f32[i];
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+ const int64_t i11 = i01;
+ const int64_t i12 = i02%ne12;
+ const int64_t i13 = i03%ne13;
+
+ // ALiBi
+ const uint32_t h = i02; // head
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+ float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ // broadcast the mask across rows
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+ float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+ ggml_vec_cpy_f32 (ne00, wp, sp);
+ ggml_vec_scale_f32(ne00, wp, scale);
+ if (mp_f32) {
+ if (use_f16) {
+ for (int i = 0; i < ne00; ++i) {
+ wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+ }
+ } else {
+ for (int i = 0; i < ne00; ++i) {
+ wp[i] += slope*mp_f32[i];
+ }
+ }
 }
- }
- }

 #ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- //printf("p[%d] = %f\n", i, p[i]);
- assert(!isnan(wp[i]));
- }
+ for (int i = 0; i < ne00; ++i) {
+ //printf("p[%d] = %f\n", i, p[i]);
+ assert(!isnan(wp[i]));
+ }
 #endif

- float max = -INFINITY;
- ggml_vec_max_f32(nc, &max, wp);
+ float max = -INFINITY;
+ ggml_vec_max_f32(ne00, &max, wp);

- ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
- assert(sum > 0.0);
+ // if we have sinks, make a correction as if they were included in the softmax
+ if (sk) {
+ max = MAX(max, sk[i02]);
+ }
+
+ ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+ assert(sum > 0.0);
+
+ if (sk) {
+ sum += (ggml_float) expf(sk[i02] - max);
+ }

- sum = 1.0/sum;
- ggml_vec_scale_f32(nc, dp, sum);
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(ne00, dp, sum);

 #ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- assert(!isnan(dp[i]));
- assert(!isinf(dp[i]));
- }
+ for (int i = 0; i < ne00; ++i) {
+ assert(!isnan(dp[i]));
+ assert(!isinf(dp[i]));
+ }
 #endif
+ }
+ }
 }
 }

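Note: the rewritten soft_max walks i01/i02/i03 explicitly (instead of flattened rows) so the mask can broadcast across dims 2 and 3, and it folds an optional per-head attention "sink" (src2) into the normalization: the sink logit raises the running max and adds expf(sk - max) to the denominator, so the visible probabilities intentionally sum to less than 1. A scalar sketch of that correction (illustrative names):

    static void softmax_with_sink(float * p, const float * logits, int n, float sink) {
        float mx = sink;
        for (int i = 0; i < n; ++i) mx = logits[i] > mx ? logits[i] : mx;
        float sum = expf(sink - mx); // the sink joins the normalizer only
        for (int i = 0; i < n; ++i) {
            p[i] = expf(logits[i] - mx);
            sum += p[i];
        }
        for (int i = 0; i < n; ++i) p[i] /= sum; // row sums to 1 minus the sink mass
    }
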
@@ -5534,6 +5405,7 @@ void ggml_compute_forward_clamp(
 case GGML_TYPE_Q5_1:
 case GGML_TYPE_Q8_0:
 case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
 case GGML_TYPE_Q2_K:
 case GGML_TYPE_Q3_K:
 case GGML_TYPE_Q4_K:
@@ -6460,7 +6332,195 @@ void ggml_compute_forward_im2col_back_f32(
 const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
 const ggml_tensor * src1 = dst->src[1]; // convolution kernel

- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = is_2D ? ne3 : ne2;
+ const int64_t IC = is_2D ? ne2 : ne1;
+ const int64_t IH = is_2D ? ne1 : 1;
+ const int64_t IW = ne0;
+
+ const int64_t KH = is_2D ? ne11 : 1;
+ const int64_t KW = ne10;
+
+ const int64_t OH = is_2D ? ne02 : 1;
+ const int64_t OW = ne01;
+
+ int ofs0 = is_2D ? nb3 : nb2;
+ int ofs1 = is_2D ? nb2 : nb1;
+
+ GGML_ASSERT(nb0 == sizeof(float));
+
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+ {
+ float * const wdata = (float *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+ for (int64_t iih = 0; iih < IH; iih++) {
+ for (int64_t iiw = 0; iiw < IW; iiw++) {
+
+ // micro kernel
+ float grad = 0.0f;
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ // For s0 > 1 some values were skipped over in the forward pass.
+ // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
+ const int64_t tmpw = (iiw + p0 - ikw*d0);
+ if (tmpw % s0 != 0) {
+ continue;
+ }
+ const int64_t iow = tmpw / s0;
+
+ // Equivalent logic as above except for s1.
+ int64_t ioh;
+ if (is_2D) {
+ const int64_t tmph = iih + p1 - ikh*d1;
+
+ if (tmph % s1 != 0) {
+ continue;
+ }
+
+ ioh = tmph / s1;
+ } else {
+ ioh = 0;
+ }
+
+ if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
+ continue;
+ }
+
+ const float * const grad_in = (const float *) src0->data
+ + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+ grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
+ }
+ }
+ float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
+ dst_data[iih*IW + iiw] = grad;
+ }
+ }
+ }
+ }
+ }
+ }
+
+
+ // ggml_compute_forward_im2col_3d_f16
+ // src0: kernel [OC*IC, KD, KH, KW]
+ // src1: image [N*IC, ID, IH, IW]
+ // dst: result [N*OD, OH, OW, IC * KD * KH * KW]
+ static void ggml_compute_forward_im2col_3d_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = ne13 / IC;
+ const int64_t ID = ne12;
+ const int64_t IH = ne11;
+ const int64_t IW = ne10;
+
+ const int64_t OC = ne03 / IC;
+ GGML_UNUSED(OC);
+ const int64_t KD = ne02;
+ const int64_t KH = ne01;
+ const int64_t KW = ne00;
+
+ const int64_t OD = ne3 / N;
+ const int64_t OH = ne2;
+ const int64_t OW = ne1;
+ const int64_t OH_OW = OH*OW;
+ const int64_t KD_KH_KW = KD*KH*KW;
+ const int64_t KH_KW = KH*KW;
+ const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+ for (int64_t iod = 0; iod < OD; iod++) {
+ for (int64_t ioh = 0; ioh < OH; ioh++) {
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+
+ // micro kernel
+ ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+ const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+ for (int64_t ikd = 0; ikd < KD; ikd++) {
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+ const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+ } else {
+ const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // ggml_compute_forward_im2col_3d_f32
+ // src0: kernel [OC*IC, KD, KH, KW]
+ // src1: image [N*IC, ID, IH, IW]
+ // dst: result [N*OD, OH, OW, IC * KD * KH * KW]
+ static void ggml_compute_forward_im2col_3d_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
 GGML_ASSERT(src1->type == GGML_TYPE_F32);
 GGML_ASSERT( dst->type == GGML_TYPE_F32);

@@ -6468,77 +6528,72 @@ void ggml_compute_forward_im2col_back_f32(

 const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
 const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
- const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
- const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
- const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
- const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
- const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+

 const int ith = params->ith;
 const int nth = params->nth;

- const int64_t N = is_2D ? ne3 : ne2;
- const int64_t IC = is_2D ? ne2 : ne1;
- const int64_t IH = is_2D ? ne1 : 1;
- const int64_t IW = ne0;
+ const int64_t N = ne13 / IC;
+ const int64_t ID = ne12;
+ const int64_t IH = ne11;
+ const int64_t IW = ne10;

- const int64_t KH = is_2D ? ne11 : 1;
- const int64_t KW = ne10;
+ const int64_t OC = ne03 / IC;
+ GGML_UNUSED(OC);
+ const int64_t KD = ne02;
+ const int64_t KH = ne01;
+ const int64_t KW = ne00;

- const int64_t OH = is_2D ? ne02 : 1;
- const int64_t OW = ne01;
+ const int64_t OD = ne3 / N;
+ const int64_t OH = ne2;
+ const int64_t OW = ne1;

- int ofs0 = is_2D ? nb3 : nb2;
- int ofs1 = is_2D ? nb2 : nb1;
+ const int64_t OH_OW = OH*OW;
+ const int64_t KD_KH_KW = KD*KH*KW;
+ const int64_t KH_KW = KH*KW;
+ const int64_t IC_KD_KH_KW = IC*KD*KH*KW;

- GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));

- // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+ // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
 {
 float * const wdata = (float *) dst->data;

 for (int64_t in = 0; in < N; in++) {
- for (int64_t iic = ith; iic < IC; iic += nth) {
- for (int64_t iih = 0; iih < IH; iih++) {
- for (int64_t iiw = 0; iiw < IW; iiw++) {
-
- // micro kernel
- float grad = 0.0f;
- for (int64_t ikh = 0; ikh < KH; ikh++) {
- for (int64_t ikw = 0; ikw < KW; ikw++) {
- // For s0 > 1 some values were skipped over in the forward pass.
- // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
- const int64_t tmpw = (iiw + p0 - ikw*d0);
- if (tmpw % s0 != 0) {
- continue;
- }
- const int64_t iow = tmpw / s0;
-
- // Equivalent logic as above except for s1.
- int64_t ioh;
- if (is_2D) {
- const int64_t tmph = iih + p1 - ikh*d1;
-
- if (tmph % s1 != 0) {
- continue;
+ for (int64_t iod = 0; iod < OD; iod++) {
+ for (int64_t ioh = 0; ioh < OH; ioh++) {
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+
+ // micro kernel
+ float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+ const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+ for (int64_t ikd = 0; ikd < KD; ikd++) {
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+ const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+ } else {
+ const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
+ }
 }
-
- ioh = tmph / s1;
- } else {
- ioh = 0;
- }
-
- if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
- continue;
 }
-
- const float * const grad_in = (const float *) src0->data
- + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
- grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
 }
 }
- float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
- dst_data[iih*IW + iiw] = grad;
 }
 }
 }
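Note: im2col_3d (both the F16 and F32 variants above) writes, for every output voxel, the flattened [IC, KD, KH, KW] receptive field, zero-filling taps that fall outside the input. The coordinate mapping is the usual strided/dilated/padded one; as a sketch:

    // input coordinate of kernel tap k for output coordinate o
    // (negative or >= the input extent means zero padding)
    static inline int64_t im2col_tap(int64_t o, int32_t s, int64_t k, int32_t d, int32_t p) {
        return o*s + k*d - p;
    }
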
@@ -6546,6 +6601,26 @@ void ggml_compute_forward_im2col_back_f32(
 }
 }

+
+ void ggml_compute_forward_im2col_3d(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+ switch (dst->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_im2col_3d_f16(params, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_im2col_3d_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
 static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
 void * a, void * b, float * c) {
 const ggml_type_traits * traits = ggml_get_type_traits(type);
@@ -6726,6 +6801,148 @@ void ggml_compute_forward_conv_2d(
 ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
 }

+ // ggml_compute_forward_conv_3d
+
+ static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
+ const ggml_tensor * kernel,
+ const ggml_tensor * src,
+ ggml_tensor * dst,
+ ggml_type kernel_type) {
+
+ GGML_ASSERT(ggml_is_contiguous(kernel));
+ GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+ GGML_ASSERT(kernel->type == kernel_type);
+
+ const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+ const int32_t s0 = dst->op_params[0];
+ const int32_t s1 = dst->op_params[1];
+ const int32_t s2 = dst->op_params[2];
+ const int32_t p0 = dst->op_params[3];
+ const int32_t p1 = dst->op_params[4];
+ const int32_t p2 = dst->op_params[5];
+ const int32_t d0 = dst->op_params[6];
+ const int32_t d1 = dst->op_params[7];
+ const int32_t d2 = dst->op_params[8];
+ const int32_t c = dst->op_params[9];
+ const int32_t n = dst->op_params[10];
+ const int32_t oc = dst->op_params[11];
+
+ const int64_t src_w = src->ne[0];
+ const int64_t src_h = src->ne[1];
+ const int64_t src_d = src->ne[2];
+ const int64_t knl_w = kernel->ne[0];
+ const int64_t knl_h = kernel->ne[1];
+ const int64_t knl_d = kernel->ne[2];
+ const int64_t dst_w = dst->ne[0];
+ const int64_t dst_h = dst->ne[1];
+ const int64_t dst_d = dst->ne[2];
+
+ const float * src_data = (float *) src->data;
+ void * knl_data = kernel->data;
+ float * dst_data = (float *) dst->data;
+
+ const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
+ const int64_t knl_n_total = knl_n_per_channel * c;
+ const int64_t patch_total = n * dst_w * dst_h * dst_d;
+
+ const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float);
+ const int64_t batch_size = params->wsize / space_per_patch;
+ const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+ const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+ GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+ void * tmp = params->wdata;
+
+ for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+ const int64_t patch_start_batch = batch_i * patches_per_batch;
+ const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
+ const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;
+
+ const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+ const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
+ const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+ for (int64_t p = patch_start; p < patch_end; ++p) {
+ const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+ const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+ const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
+ const int64_t dst_z = p_in_batch / (dst_w * dst_h);
+ const int64_t dst_y = p_in_depth / dst_w;
+ const int64_t dst_x = p_in_depth % dst_w;
+
+ char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
+
+ for (int64_t ic = 0; ic < c; ++ic) {
+ for (int64_t kz = 0; kz < knl_d; ++kz) {
+ for (int64_t ky = 0; ky < knl_h; ++ky) {
+ for (int64_t kx = 0; kx < knl_w; ++kx) {
+ const int64_t sz = dst_z * s2 + kz * d2 - p2;
+ const int64_t sy = dst_y * s1 + ky * d1 - p1;
+ const int64_t sx = dst_x * s0 + kx * d0 - p0;
+
+ int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
+
+ float src_val;
+ if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+ src_val = 0.0f;
+ } else {
+ const int64_t cn_idx = batch_idx * c + ic;
+ const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
+ src_val = *src_ptr;
+ }
+
+ char * element_ptr = dst_row + dst_idx * traits->type_size;
+ if (kernel_type == GGML_TYPE_F32) {
+ *(float *)element_ptr = src_val;
+ } else if (kernel_type == GGML_TYPE_F16) {
+ *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ ggml_barrier(params->threadpool);
+
+ float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
+ ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
+
+ ggml_barrier(params->threadpool);
+
+ const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+ const int64_t permute_start = params->ith * permute_per_thread;
+ const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
+
+ for (int64_t i = permute_start; i < permute_end; ++i) {
+ const int64_t p = patch_start_batch + i;
+ const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+ const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+ const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
+ const int64_t dst_z = p_in_batch / (dst_w * dst_h);
+ const int64_t dst_y = p_in_depth / dst_w;
+ const int64_t dst_x = p_in_depth % dst_w;
+
+ for (int64_t ioc = 0; ioc < oc; ++ioc) {
+ const float value = gemm_output[i * oc + ioc];
+ const int64_t ocn_idx = batch_idx * oc + ioc;
+ float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
+ *dst_ptr = value;
+ }
+ }
+ }
+ }
+
+ void ggml_compute_forward_conv_3d(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
+ }
+
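Note: the new conv_3d runs as batched im2col + GEMM: patches are packed into the work buffer, multiplied against the kernel with ggml_call_mul_mat, then scattered into dst, with ggml_barrier separating the phases. The batch size is derived from the work-buffer size; a free-standing sketch of that arithmetic (names illustrative):

    static int64_t conv3d_num_batches(int64_t wsize, int64_t knl_n_total, size_t type_size,
                                      int64_t oc, int64_t patch_total) {
        const int64_t space_per_patch   = knl_n_total * type_size + oc * sizeof(float); // im2col row + GEMM row
        const int64_t batch_size        = wsize / space_per_patch;
        const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
        return (patch_total + patches_per_batch - 1) / patches_per_batch; // GEMM rounds
    }
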
 // ggml_compute_forward_conv_transpose_2d

 void ggml_compute_forward_conv_transpose_2d(
@@ -7391,6 +7608,15 @@ static void ggml_compute_forward_pad_f32(
 GGML_TENSOR_UNARY_OP_LOCALS

 float * dst_ptr = (float *) dst->data;
+ const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
+ const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
+ const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
+ const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
+ const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
+ const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
+ const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
+ const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
+

 // TODO: optimize

@@ -7399,10 +7625,12 @@ static void ggml_compute_forward_pad_f32(
 for (int64_t i0 = 0; i0 < ne0; ++i0) {
 for (int64_t i3 = 0; i3 < ne3; ++i3) {
 const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
- const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-
- if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ if ((i0 >= lp0 && i0 < ne0 - rp0) \
+ && (i1 >= lp1 && i1 < ne1 - rp1) \
+ && (i2 >= lp2 && i2 < ne2 - rp2) \
+ && (i3 >= lp3 && i3 < ne3 - rp3)) {
+ const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+ const float * src_ptr = (const float *)((char *) src0->data + src_idx);
 dst_ptr[dst_idx] = *src_ptr;
 } else {
 dst_ptr[dst_idx] = 0;
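Note: pad_f32 now supports asymmetric per-dimension padding: eight op params give a left/right pad pair for each of the four dims, and an output element copies from index i - lp only when it lies inside the kept window. The 1-D picture, as a sketch:

    // ne == lp + ne_src + rp for a pure pad
    static void pad1d_ref(float * dst, const float * src, int ne, int lp, int rp) {
        for (int i = 0; i < ne; ++i) {
            dst[i] = (i >= lp && i < ne - rp) ? src[i - lp] : 0.0f;
        }
    }
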
@@ -7601,7 +7829,7 @@ static void ggml_compute_forward_timestep_embedding_f32(
 embed_data[j + half] = sinf(arg);
 }
 if (dim % 2 != 0 && ith == 0) {
- embed_data[dim] = 0.f;
+ embed_data[2 * half] = 0.f;
 }
 }
 }
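Note: the timestep-embedding change fixes an off-by-one for odd dim: with half = dim/2, cos fills indices 0..half-1 and sin fills half..2*half-1, so the one unpaired slot is 2*half == dim - 1; the old code zeroed embed_data[dim], one element past the row. For example, dim = 5 gives half = 2, cos at indices 0..1, sin at 2..3, and the zero pad belongs at index 4.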
@@ -7687,12 +7915,14 @@ void ggml_compute_forward_argsort(

 static void ggml_compute_forward_flash_attn_ext_f16(
 const ggml_compute_params * params,
- const ggml_tensor * q,
- const ggml_tensor * k,
- const ggml_tensor * v,
- const ggml_tensor * mask,
 ggml_tensor * dst) {

+ const ggml_tensor * q = dst->src[0];
+ const ggml_tensor * k = dst->src[1];
+ const ggml_tensor * v = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * sinks = dst->src[4];
+
 GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
 GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
 GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
@@ -7766,7 +7996,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

- ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+ ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
 ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
 ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
 ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
@@ -7798,7 +8028,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 memset(VKQ32, 0, DV*sizeof(float));
 }

- const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;

 // k indices
 const int ik3 = iq3 / rk3;
@@ -7887,6 +8117,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 }
 }

+ // sinks
+ if (sinks) {
+ const float s = ((float *)((char *) sinks->data))[h];
+
+ float ms = 1.0f;
+ float vs = 1.0f;
+
+ if (s > M) {
+ ms = expf(M - s);
+ ggml_vec_scale_f32(DV, VKQ32, ms);
+ } else {
+ vs = expf(s - M);
+ }
+
+ S = S*ms + vs;
+ }
+
 // V /= S
 const float S_inv = 1.0f/S;
 ggml_vec_scale_f32(DV, VKQ32, S_inv);
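Note: in flash attention the sink is folded in online-softmax style against the running maximum M and normalizer S: if the sink logit exceeds M, the accumulated values are rescaled; otherwise it only adds mass to S. The sink contributes to the normalizer but has no value vector. Equivalent update, as a sketch:

    // V[0..DV) is the value accumulator; afterwards the effective max is max(M, s)
    if (s > M) {
        const float ms = expf(M - s);          // rescale history to the new max
        for (int i = 0; i < DV; ++i) V[i] *= ms;
        S = S*ms + 1.0f;                       // the sink itself contributes exp(s - s) == 1
    } else {
        S += expf(s - M);                      // sink adds mass, no value
    }
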
@@ -7906,17 +8153,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(

 void ggml_compute_forward_flash_attn_ext(
 const ggml_compute_params * params,
- const ggml_tensor * q,
- const ggml_tensor * k,
- const ggml_tensor * v,
- const ggml_tensor * mask,
 ggml_tensor * dst) {
 switch (dst->op_params[3]) {
 case GGML_PREC_DEFAULT:
 case GGML_PREC_F32:
 {
 // uses F32 accumulators
- ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+ ggml_compute_forward_flash_attn_ext_f16(params, dst);
 } break;
 default:
 {
@@ -8336,120 +8579,214 @@ void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
- const ggml_tensor * src0 = dst->src[0]; // s
- const ggml_tensor * src1 = dst->src[1]; // x
- const ggml_tensor * src2 = dst->src[2]; // dt
- const ggml_tensor * src3 = dst->src[3]; // A
- const ggml_tensor * src4 = dst->src[4]; // B
- const ggml_tensor * src5 = dst->src[5]; // C
+ const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
+ const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
+ const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
+ const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
+ const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
+ const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
+ const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}

 const int ith = params->ith;
 const int nth = params->nth;

- const int64_t nc = src0->ne[0]; // d_state
- const int64_t nr = src0->ne[1]; // d_inner
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+ const int64_t nc = src0->ne[0]; // d_state
+ const int64_t nr = src0->ne[1]; // dim
+ const int64_t nh = src1->ne[1]; // n_head
+ const int64_t ng = src4->ne[1];
+ const int64_t nt = src1->ne[2]; // number of tokens per sequence
+ const int64_t ns = src1->ne[3]; // number of sequences in the batch
+
+ // can't use ggml_nbytes because src1 is not necessarily contiguous
+ const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);

- GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+ GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
 GGML_ASSERT(src0->nb[0] == sizeof(float));
 GGML_ASSERT(src1->nb[0] == sizeof(float));
 GGML_ASSERT(src2->nb[0] == sizeof(float));
 GGML_ASSERT(src3->nb[0] == sizeof(float));
 GGML_ASSERT(src4->nb[0] == sizeof(float));
 GGML_ASSERT(src5->nb[0] == sizeof(float));
- // required for the dot product between s and C
- GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
- // required for per-sequence offsets for states
- GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
- // required to get correct offset for state destination (i.e. src1->nb[3])
- GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+ GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+ GGML_ASSERT(nh % ng == 0);

- // rows per thread
- const int dr = (nr + nth - 1)/nth;
+ // heads per thread
+ const int dh = (nh + nth - 1)/nth;

- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
- const int ir = ir1 - ir0;
+ // head range for this thread
+ const int ih0 = dh*ith;
+ const int ih1 = MIN(ih0 + dh, nh);
+
+ const int32_t * ids = (const int32_t *) src6->data;
+
+ for (int i3 = 0; i3 < ns; ++i3) {
+ const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+ float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
+ for (int i2 = 0; i2 < nt; ++i2) {
+ const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
+ const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
+ const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
+ float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
+
+ if (src3->ne[0] == 1) {
+ // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
+
+ // n_head
+ for (int h = ih0; h < ih1; ++h) {
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const float dA = expf(dt_soft_plus * A[h]);
+ const int g = h / (nh / ng); // repeat_interleave
+
+ // dim
+ for (int i1 = 0; i1 < nr; ++i1) {
+ const int ii = i1 + h*nr;
+ const float x_dt = x[ii] * dt_soft_plus;
+ float sumf = 0.0f;
+ #if defined(GGML_SIMD)
+ #if defined(__ARM_FEATURE_SVE)
+ const int ggml_f32_epr = svcntw();
+ const int ggml_f32_step = 1 * ggml_f32_epr;
+
+ const int np = (nc & ~(ggml_f32_step - 1));
+
+ GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+ for (int i = 0; i < np; i += ggml_f32_step) {
+ // TODO: maybe unroll more?
+ for (int j = 0; j < 1; j++) {
+ GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
+
+ t0 = GGML_F32_VEC_MUL(t0, adA);
+ t1 = GGML_F32_VEC_MUL(t1, axdt);
+
+ t0 = GGML_F32_VEC_ADD(t0, t1);
+
+ sum = GGML_F32_VEC_FMA(sum, t0, t2);
+
+ GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
+ }
+ }
+
+ sumf = GGML_F32xt_REDUCE_ONE(sum);
+ #elif defined(__riscv_v_intrinsic)
+ // todo: RVV implementation
+ const int np = 0;
+ #else
+ const int np = (nc & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+ GGML_F32_VEC az[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
+
+ ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
+ ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
+
+ ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
+
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
+
+ GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
+ }
+ }

- #ifdef __ARM_FEATURE_SVE
- for (int i3 = 0; i3 < n_s; ++i3) {
- for (int i2 = 0; i2 < n_t; ++i2) {
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
- // use the output as the source for the next token-wise iterations
- if (i2 > 0) { s0 = s; }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
- float x_dt = x[i1] * dt_soft_plus;
- svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
- svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
- svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
-
- for (int64_t k = 0; k < nc; k += svcntw()) {
- svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
- svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
- svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
- svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
-
- svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
- t1 = exp_ps_sve(svptrue_b32(), t1);
- svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
- vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
- r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
- GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+ // reduce sum0..sum3 to sum0
+ GGML_F32_VEC_REDUCE(sumf, sum);
+ #endif
+ #else
+ const int np = 0;
+ #endif
+ // d_state
+ for (int i0 = np; i0 < nc; ++i0) {
+ const int i = i0 + ii*nc;
+ const int ig = i0 + g*nc;
+ // state = prev_state * dA + dB * x
+ const float state = (s0[i] * dA) + (B[ig] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[ig];
+ s[i] = state;
+ }
+ y[ii] = sumf;
 }
- y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
 }
- }
- }
- #else
- for (int i3 = 0; i3 < n_s; ++i3) {
- for (int i2 = 0; i2 < n_t; ++i2) {
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
- // use the output as the source for the next token-wise iterations
- if (i2 > 0) { s0 = s; }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
- float x_dt = x[i1] * dt_soft_plus;
- float sumf = 0.0f;
- // d_state
- for (int i0 = 0; i0 < nc; ++i0) {
- int i = i0 + i1*nc;
- // state = prev_state * dA + dB * x
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
- // y = rowwise_dotprod(state, C)
- sumf += state * C[i0];
- s[i] = state;
+ } else {
+ // Mamba-1 has an element-wise decay factor for the states
+
+ // n_head
+ for (int h = ih0; h < ih1; ++h) {
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const int g = h / (nh / ng); // repeat_interleave
+
+ // dim
+ for (int i1 = 0; i1 < nr; ++i1) {
+ const int ii = i1 + h*nr;
+ const float x_dt = x[ii] * dt_soft_plus;
+ #if defined(__ARM_FEATURE_SVE)
+ svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+ svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+ svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+ // d_state
+ // TODO: what happens when (d_state % svcntw()) != 0?
+ for (int64_t k = 0; k < nc; k += svcntw()) {
+ svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
+ svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
+
+ svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+ t1 = exp_ps_sve(svptrue_b32(), t1);
+ svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+ vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
+ r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+ GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
+ }
+ y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
+ #else
+ float sumf = 0.0f;
+ // NOTE: can't really use GGML_SIMD here because d_state is usually 16
+ // and also because expf is used within the loop.
+ // d_state
+ for (int i0 = 0; i0 < nc; ++i0) {
+ const int i = i0 + ii*nc;
+ const int ig = i0 + g*nc;
+ // state = prev_state * dA + dB * x
+ const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[ig];
+ s[i] = state;
+ }
+ y[ii] = sumf;
+ #endif
 }
- y[i1] = sumf;
 }
 }
+ // use the output as the source when it's not the first token-wise iteration
+ s0 = s;
 }
- #endif
+ }
 }

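Note: the rewritten ssm_scan looks states up through the ids tensor (src6) per sequence and branches on src3->ne[0]: Mamba-2 has one scalar decay per head, so dA = expf(softplus(dt) * A[h]) hoists out of the state loop, while Mamba-1 keeps a per-state expf. Both reduce to the same recurrence, sketched here for the scalar-dA case (illustrative names):

    // state' = state*dA + B*x_dt;  y = dot(state', C) over d_state
    static float ssm_step_ref(float * state, const float * B, const float * C,
                              float dA, float x_dt, int d_state) {
        float y = 0.0f;
        for (int i = 0; i < d_state; ++i) {
            state[i] = state[i] * dA + B[i] * x_dt;
            y += state[i] * C[i];
        }
        return y;
    }
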
 void ggml_compute_forward_ssm_scan(
@@ -8688,6 +9025,18 @@ void ggml_compute_forward_glu(
 {
 ggml_compute_forward_swiglu(params, dst);
 } break;
+ case GGML_GLU_OP_SWIGLU_OAI:
+ {
+ ggml_compute_forward_swiglu_oai(params, dst);
+ } break;
+ case GGML_GLU_OP_GEGLU_ERF:
+ {
+ ggml_compute_forward_geglu_erf(params, dst);
+ } break;
+ case GGML_GLU_OP_GEGLU_QUICK:
+ {
+ ggml_compute_forward_geglu_quick(params, dst);
+ } break;
 default:
 {
 GGML_ABORT("fatal error");
@@ -9283,8 +9632,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
 int64_t h_stride_2d = head_size * head_size;

 #if defined(GGML_SIMD)
- #if defined(__ARM_FEATURE_SVE)
- // scalar Route to scalar implementation //TODO: Write SVE code
+ #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
+ // scalar Route to scalar implementation //TODO: Write SVE code and RVV code
 for (int64_t t = 0; t < T; t++) {
 int64_t t_offset = t * t_stride;
 int64_t state_offset = head_size * C * (t / (T / n_seqs));
@@ -9732,6 +10081,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
 const int ir1 = MIN(ir0 + dr, nr);

 const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
 const float alpha = adamw_params_ptr[0];
 const float beta1 = adamw_params_ptr[1];
 const float beta2 = adamw_params_ptr[2];
@@ -9739,7 +10089,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
 const float wd = adamw_params_ptr[4];
 const float beta1h = adamw_params_ptr[5];
 const float beta2h = adamw_params_ptr[6];
-
+ const float keep = 1.f - alpha * wd;
 for (int ir = ir0; ir < ir1; ++ir) {
 const int64_t i03 = ir/(ne02*ne01);
 const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -9762,7 +10112,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
 // The weight decay is applied independently of the Adam momenta m and v.
 // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
 // See: https://arxiv.org/pdf/1711.05101v3.pdf
- w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+ w[i00] = w[i00] * keep - alpha * mh / vh;
 }
 }
 }
@@ -9784,3 +10134,63 @@ void ggml_compute_forward_opt_step_adamw(
 }
 }
 }
+
+ static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src0_grad = dst->src[1];
+ const ggml_tensor * sgd_params = dst->src[2];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+ GGML_ASSERT(ggml_nelements(sgd_params) == 2);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1) / nth;
+
+ // row range for this thread
+ const int ir0 = dr * ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ // using adamw param subset we care about - alpha, wd - could have a separate struct
+ const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
+ const float alpha = sgd_params_ptr[0];
+ const float keep = 1.f - alpha * sgd_params_ptr[1];
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ const int64_t i03 = ir / (ne02 * ne01);
+ const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+ const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+ const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+ float * w = (float *) ((char *) src0->data + offset); // weight
+ const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+ for (int i00 = 0; i00 < ne00; ++i00) {
+ w[i00] = w[i00] * keep - alpha * g[i00];
+ }
+ }
+ }
+
+ void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_opt_step_sgd_f32(params, dst);
+ }
+ break;
+ default:
+ {
+ GGML_ABORT("fatal error - sgd is F32 only");
+ }
+ }
+ }
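
Note: the new SGD step uses the same decoupled weight decay as the AdamW path above: w <- w*(1 - alpha*wd) - alpha*g, with keep = 1 - alpha*wd hoisted out of the loop. As a self-contained sketch:

    static void sgd_step_ref(float * w, const float * g, int n, float alpha, float wd) {
        const float keep = 1.0f - alpha * wd;  // decoupled weight decay factor
        for (int i = 0; i < n; ++i) {
            w[i] = w[i] * keep - alpha * g[i]; // decay, then plain gradient step
        }
    }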