whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -69,6 +69,10 @@
69
69
  #define VECTOR_REGISTERS 16
70
70
  #endif
71
71
 
72
#if defined(__riscv_v_intrinsic)
// Register-group multiplier (LMUL) for the RVV tinyBLAS kernels.
// tinyBLAS_RVV::matmul picks its tile shapes from this value
// (LMUL == 1, LMUL == 2, anything else is treated as 4).
// Defaults to 4; may be overridden from the build, e.g. -DLMUL=2.
#ifndef LMUL
#define LMUL 4
#endif
#endif
75
+
72
76
  #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
73
77
 
74
78
  namespace {
@@ -117,8 +121,7 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
117
121
  #endif
118
122
 
119
123
  #if defined(__MMA__)
120
- typedef vector unsigned char vec_t;
121
- typedef __vector_quad acc_t;
124
+ #include "sgemm-ppc.h"
122
125
  #endif
123
126
  ////////////////////////////////////////////////////////////////////////////////////////////////////
124
127
  // VECTORIZED FUSED MULTIPLY ADD
@@ -176,6 +179,46 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
176
179
  }
177
180
  #endif
178
181
 
182
#if defined(__riscv_zvfh)
// RVV fused multiply-accumulate: every overload returns acc + x * y,
// computed lane-wise over VLMAX elements of the accumulator type.
//
// f16 operands use the widening vfwmacc, so an f16 register group of
// LMUL n accumulates into an f32 group of LMUL 2n (mf2->m1, m1->m2, ...).
template <>
inline vfloat32m1_t madd(vfloat16mf2_t x, vfloat16mf2_t y, vfloat32m1_t acc) {
    const size_t vl = __riscv_vsetvlmax_e32m1();
    return __riscv_vfwmacc_vv_f32m1(acc, x, y, vl);
}
inline vfloat32m2_t madd(vfloat16m1_t x, vfloat16m1_t y, vfloat32m2_t acc) {
    const size_t vl = __riscv_vsetvlmax_e32m2();
    return __riscv_vfwmacc_vv_f32m2(acc, x, y, vl);
}
inline vfloat32m4_t madd(vfloat16m2_t x, vfloat16m2_t y, vfloat32m4_t acc) {
    const size_t vl = __riscv_vsetvlmax_e32m4();
    return __riscv_vfwmacc_vv_f32m4(acc, x, y, vl);
}
inline vfloat32m8_t madd(vfloat16m4_t x, vfloat16m4_t y, vfloat32m8_t acc) {
    const size_t vl = __riscv_vsetvlmax_e32m8();
    return __riscv_vfwmacc_vv_f32m8(acc, x, y, vl);
}

// f32 operands use the non-widening vfmacc at the same LMUL.
inline vfloat32m1_t madd(vfloat32m1_t x, vfloat32m1_t y, vfloat32m1_t acc) {
    const size_t vl = __riscv_vsetvlmax_e32m1();
    return __riscv_vfmacc_vv_f32m1(acc, x, y, vl);
}
inline vfloat32m2_t madd(vfloat32m2_t x, vfloat32m2_t y, vfloat32m2_t acc) {
    const size_t vl = __riscv_vsetvlmax_e32m2();
    return __riscv_vfmacc_vv_f32m2(acc, x, y, vl);
}
inline vfloat32m4_t madd(vfloat32m4_t x, vfloat32m4_t y, vfloat32m4_t acc) {
    const size_t vl = __riscv_vsetvlmax_e32m4();
    return __riscv_vfmacc_vv_f32m4(acc, x, y, vl);
}
inline vfloat32m8_t madd(vfloat32m8_t x, vfloat32m8_t y, vfloat32m8_t acc) {
    const size_t vl = __riscv_vsetvlmax_e32m8();
    return __riscv_vfmacc_vv_f32m8(acc, x, y, vl);
}
#endif
209
+
210
#if defined(__riscv_zvfbfwma)
// bf16 fused multiply-accumulate (Zvfbfwma): returns acc + x * y with the
// bf16 operands widened into an f32 accumulator of twice the LMUL.
inline vfloat32m1_t madd(vbfloat16mf2_t x, vbfloat16mf2_t y, vfloat32m1_t acc) {
    const size_t vl = __riscv_vsetvlmax_e32m1();
    return __riscv_vfwmaccbf16_vv_f32m1(acc, x, y, vl);
}
inline vfloat32m2_t madd(vbfloat16m1_t x, vbfloat16m1_t y, vfloat32m2_t acc) {
    const size_t vl = __riscv_vsetvlmax_e32m2();
    return __riscv_vfwmaccbf16_vv_f32m2(acc, x, y, vl);
}
inline vfloat32m4_t madd(vbfloat16m2_t x, vbfloat16m2_t y, vfloat32m4_t acc) {
    const size_t vl = __riscv_vsetvlmax_e32m4();
    return __riscv_vfwmaccbf16_vv_f32m4(acc, x, y, vl);
}
#endif
221
+
179
222
  ////////////////////////////////////////////////////////////////////////////////////////////////////
180
223
  // VECTORIZED HORIZONTAL SUM
181
224
 
@@ -228,6 +271,25 @@ inline float hsum(__m512 x) {
228
271
  }
229
272
  #endif // __AVX512F__
230
273
 
274
#if defined(__riscv_zvfh)
// Horizontal sums: reduce all VLMAX lanes of an f32 register group to a
// single float. The reduction is seeded with a one-element zero vector and
// uses the unordered sum (vfredusum), so the summation order is unspecified.
inline float hsum(vfloat32m1_t x) {
    const vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(0, 1);
    const vfloat32m1_t total = __riscv_vfredusum_vs_f32m1_f32m1(x, seed, __riscv_vsetvlmax_e32m1());
    return __riscv_vfmv_f_s_f32m1_f32(total);
}
inline float hsum(vfloat32m2_t x) {
    const vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(0, 1);
    const vfloat32m1_t total = __riscv_vfredusum_vs_f32m2_f32m1(x, seed, __riscv_vsetvlmax_e32m2());
    return __riscv_vfmv_f_s_f32m1_f32(total);
}
inline float hsum(vfloat32m4_t x) {
    const vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(0, 1);
    const vfloat32m1_t total = __riscv_vfredusum_vs_f32m4_f32m1(x, seed, __riscv_vsetvlmax_e32m4());
    return __riscv_vfmv_f_s_f32m1_f32(total);
}
inline float hsum(vfloat32m8_t x) {
    const vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(0, 1);
    const vfloat32m1_t total = __riscv_vfredusum_vs_f32m8_f32m1(x, seed, __riscv_vsetvlmax_e32m8());
    return __riscv_vfmv_f_s_f32m1_f32(total);
}
#endif
292
+
231
293
  ////////////////////////////////////////////////////////////////////////////////////////////////////
232
294
  // VECTORIZED MEMORY LOADING
233
295
 
@@ -316,6 +378,88 @@ template <> inline __m256bh load(const float *p) {
316
378
  }
317
379
  #endif
318
380
 
381
#if defined(__riscv_zvfh)
// Unit-stride loads of VLMAX elements starting at p.
// ggml_fp16_t storage is reinterpreted as _Float16 for the f16 loads.
template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
    const _Float16 *src = reinterpret_cast<const _Float16 *>(p);
    const size_t vl = __riscv_vsetvlmax_e16mf2();
    return __riscv_vle16_v_f16mf2(src, vl);
}
template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
    const _Float16 *src = reinterpret_cast<const _Float16 *>(p);
    const size_t vl = __riscv_vsetvlmax_e16m1();
    return __riscv_vle16_v_f16m1(src, vl);
}
template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
    const _Float16 *src = reinterpret_cast<const _Float16 *>(p);
    const size_t vl = __riscv_vsetvlmax_e16m2();
    return __riscv_vle16_v_f16m2(src, vl);
}
template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
    const _Float16 *src = reinterpret_cast<const _Float16 *>(p);
    const size_t vl = __riscv_vsetvlmax_e16m4();
    return __riscv_vle16_v_f16m4(src, vl);
}

// f32 loads read the float buffer directly.
template <> inline vfloat32m1_t load(const float *p) {
    const size_t vl = __riscv_vsetvlmax_e32m1();
    return __riscv_vle32_v_f32m1(p, vl);
}
template <> inline vfloat32m2_t load(const float *p) {
    const size_t vl = __riscv_vsetvlmax_e32m2();
    return __riscv_vle32_v_f32m2(p, vl);
}
template <> inline vfloat32m4_t load(const float *p) {
    const size_t vl = __riscv_vsetvlmax_e32m4();
    return __riscv_vle32_v_f32m4(p, vl);
}
template <> inline vfloat32m8_t load(const float *p) {
    const size_t vl = __riscv_vsetvlmax_e32m8();
    return __riscv_vle32_v_f32m8(p, vl);
}
#endif
407
+
408
#if defined(__riscv_zvfbfwma)
// Unit-stride bf16 loads of VLMAX elements; ggml_bf16_t storage is
// reinterpreted as __bf16 for the RVV bf16 load intrinsics.
template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
    const __bf16 *src = reinterpret_cast<const __bf16*>(p);
    const size_t vl = __riscv_vsetvlmax_e16mf2();
    return __riscv_vle16_v_bf16mf2(src, vl);
}
template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
    const __bf16 *src = reinterpret_cast<const __bf16*>(p);
    const size_t vl = __riscv_vsetvlmax_e16m1();
    return __riscv_vle16_v_bf16m1(src, vl);
}
template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
    const __bf16 *src = reinterpret_cast<const __bf16*>(p);
    const size_t vl = __riscv_vsetvlmax_e16m2();
    return __riscv_vle16_v_bf16m2(src, vl);
}
#endif
419
+
420
#if defined(__riscv_zvfh)
// set_zero<T>() returns a register group of type T whose VLMAX lanes are
// all 0 (broadcast via vfmv.v.f). Only the RVV types specialized below exist.
template <typename T> T set_zero();

// f16 register groups.
template <> inline vfloat16mf2_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e16mf2();
    return __riscv_vfmv_v_f_f16mf2(0, vl);
}
template <> inline vfloat16m1_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e16m1();
    return __riscv_vfmv_v_f_f16m1(0, vl);
}
template <> inline vfloat16m2_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e16m2();
    return __riscv_vfmv_v_f_f16m2(0, vl);
}
template <> inline vfloat16m4_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e16m4();
    return __riscv_vfmv_v_f_f16m4(0, vl);
}

// f32 register groups (used as accumulators).
template <> inline vfloat32m1_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e32m1();
    return __riscv_vfmv_v_f_f32m1(0.0f, vl);
}
template <> inline vfloat32m2_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e32m2();
    return __riscv_vfmv_v_f_f32m2(0, vl);
}
template <> inline vfloat32m4_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e32m4();
    return __riscv_vfmv_v_f_f32m4(0, vl);
}
template <> inline vfloat32m8_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e32m8();
    return __riscv_vfmv_v_f_f32m8(0, vl);
}
#endif
448
+
449
#if defined(__riscv_v_intrinsic)
// Returns VLMAX (the number of elements in one register group) for the RVV
// vector type T.
//
// tinyBLAS_RVV::matmul rejects k unless k % vlmax<V>() == 0, and its gemm
// loops advance by vlmax<V>(). The result must therefore be non-zero for
// every V the kernels are instantiated with. A 0 here means a division by
// zero in matmul and a non-advancing loop.
template <typename T> size_t vlmax() {
    if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
    else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
    else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
    else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
    else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
    else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
    else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
    else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
#if defined(__riscv_zvfbfwma)
    // bf16 vectors (the V operand type of the bf16 load/madd overloads) share
    // the 16-bit element width of the f16 types, so they have the same VLMAX.
    else if constexpr (std::is_same_v<T, vbfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
    else if constexpr (std::is_same_v<T, vbfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
    else if constexpr (std::is_same_v<T, vbfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
#endif
    // Reached only for vector types not listed above; callers must not
    // use such types with tinyBLAS_RVV (see the note on 0 above).
    return 0;
}
#endif
462
+
319
463
  ////////////////////////////////////////////////////////////////////////////////////////////////////
320
464
  // FLOATING POINT MATRIX MULTIPLICATION
321
465
 
@@ -489,6 +633,573 @@ class tinyBLAS {
489
633
  const int64_t ldc;
490
634
  };
491
635
 
636
+ #if defined(__riscv_v_intrinsic)
637
+ template <typename D, typename V, typename TA, typename TB, typename TC>
638
+ class tinyBLAS_RVV {
639
+ public:
640
+ tinyBLAS_RVV(const ggml_compute_params * params, int64_t k,
641
+ const TA *A, int64_t lda,
642
+ const TB *B, int64_t ldb,
643
+ TC *C, int64_t ldc)
644
+ : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
645
+ }
646
+
647
+ bool matmul(int64_t m, int64_t n) {
648
+ if (k % vlmax<V>() != 0) {
649
+ return false;
650
+ }
651
+
652
+ #if LMUL == 1
653
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
654
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
655
+ mnpack<4, 6, 4>(m, n, SIZE_N, 12);
656
+ return true;
657
+ }
658
+ if (m % 8 == 0 ) {
659
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
660
+ mnpack<4, 6, 2>(m, n, SIZE_N, 12);
661
+ return true;
662
+ }
663
+ if (m % 4 == 0) {
664
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
665
+ mnpack<4, 6, 1>(m, n, SIZE_N, 12);
666
+ return true;
667
+ }
668
+ #elif LMUL == 2
669
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
670
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
671
+ mnpack<4, 3, 4>(m, n, SIZE_N, 24);
672
+ return true;
673
+ }
674
+ if (m % 8 == 0 ) {
675
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
676
+ mnpack<4, 3, 2>(m, n, SIZE_N, 24);
677
+ return true;
678
+ }
679
+ if (m % 4 == 0) {
680
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
681
+ mnpack<4, 3, 1>(m, n, SIZE_N, 24);
682
+ return true;
683
+ }
684
+ #else // LMUL = 4
685
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
686
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
687
+ mnpack<2, 2, 8>(m, n, SIZE_N, 36);
688
+ return true;
689
+ }
690
+ if (m % 8 == 0 ) {
691
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
692
+ mnpack<2, 2, 4>(m, n, SIZE_N, 36);
693
+ return true;
694
+ }
695
+ if (m % 4 == 0) {
696
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
697
+ mnpack<2, 2, 2>(m, n, SIZE_N, 36);
698
+ return true;
699
+ }
700
+ #endif
701
+ return false;
702
+ }
703
+
704
+ private:
705
+ template<int RM, int RN, int BM>
706
+ inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
707
+ if (SIZE_N == RN) {
708
+ return gemm<RM, RN, BM>(m, n, BN);
709
+ }
710
+ if constexpr (RN > 1) {
711
+ return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
712
+ } else {
713
+ GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
714
+ GGML_ASSERT(false); // we have miss something.
715
+ }
716
+ }
717
+
718
+ inline void gemm_bloc_4x6(int64_t ii, int64_t jj) {
719
+ size_t vl = vlmax<V>();
720
+ D Cv00 = set_zero<D>();
721
+ D Cv01 = set_zero<D>();
722
+ D Cv02 = set_zero<D>();
723
+ D Cv03 = set_zero<D>();
724
+ D Cv10 = set_zero<D>();
725
+ D Cv11 = set_zero<D>();
726
+ D Cv12 = set_zero<D>();
727
+ D Cv13 = set_zero<D>();
728
+ D Cv20 = set_zero<D>();
729
+ D Cv21 = set_zero<D>();
730
+ D Cv22 = set_zero<D>();
731
+ D Cv23 = set_zero<D>();
732
+ D Cv30 = set_zero<D>();
733
+ D Cv31 = set_zero<D>();
734
+ D Cv32 = set_zero<D>();
735
+ D Cv33 = set_zero<D>();
736
+ D Cv40 = set_zero<D>();
737
+ D Cv41 = set_zero<D>();
738
+ D Cv42 = set_zero<D>();
739
+ D Cv43 = set_zero<D>();
740
+ D Cv50 = set_zero<D>();
741
+ D Cv51 = set_zero<D>();
742
+ D Cv52 = set_zero<D>();
743
+ D Cv53 = set_zero<D>();
744
+
745
+ for (int64_t l = 0; l < k; l += vl) {
746
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
747
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
748
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
749
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
750
+ V Bv4 = load<V>(B + ldb * (jj + 4) + l);
751
+ V Bv5 = load<V>(B + ldb * (jj + 5) + l);
752
+
753
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
754
+ Cv00 = madd(Av0, Bv0, Cv00);
755
+ Cv10 = madd(Av0, Bv1, Cv10);
756
+ Cv20 = madd(Av0, Bv2, Cv20);
757
+ Cv30 = madd(Av0, Bv3, Cv30);
758
+ Cv40 = madd(Av0, Bv4, Cv40);
759
+ Cv50 = madd(Av0, Bv5, Cv50);
760
+
761
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
762
+ Cv01 = madd(Av1, Bv0, Cv01);
763
+ Cv11 = madd(Av1, Bv1, Cv11);
764
+ Cv21 = madd(Av1, Bv2, Cv21);
765
+ Cv31 = madd(Av1, Bv3, Cv31);
766
+ Cv41 = madd(Av1, Bv4, Cv41);
767
+ Cv51 = madd(Av1, Bv5, Cv51);
768
+
769
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
770
+ Cv02 = madd(Av2, Bv0, Cv02);
771
+ Cv12 = madd(Av2, Bv1, Cv12);
772
+ Cv22 = madd(Av2, Bv2, Cv22);
773
+ Cv32 = madd(Av2, Bv3, Cv32);
774
+ Cv42 = madd(Av2, Bv4, Cv42);
775
+ Cv52 = madd(Av2, Bv5, Cv52);
776
+
777
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
778
+ Cv03 = madd(Av3, Bv0, Cv03);
779
+ Cv13 = madd(Av3, Bv1, Cv13);
780
+ Cv23 = madd(Av3, Bv2, Cv23);
781
+ Cv33 = madd(Av3, Bv3, Cv33);
782
+ Cv43 = madd(Av3, Bv4, Cv43);
783
+ Cv53 = madd(Av3, Bv5, Cv53);
784
+ }
785
+
786
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
787
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
788
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
789
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
790
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
791
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
792
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
793
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
794
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
795
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
796
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
797
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
798
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
799
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
800
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
801
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
802
+ C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
803
+ C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
804
+ C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
805
+ C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
806
+ C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50);
807
+ C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51);
808
+ C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52);
809
+ C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53);
810
+ }
811
+
812
+ inline void gemm_bloc_4x5(int64_t ii, int64_t jj) {
813
+ size_t vl = vlmax<V>();
814
+ D Cv00 = set_zero<D>();
815
+ D Cv01 = set_zero<D>();
816
+ D Cv02 = set_zero<D>();
817
+ D Cv03 = set_zero<D>();
818
+ D Cv10 = set_zero<D>();
819
+ D Cv11 = set_zero<D>();
820
+ D Cv12 = set_zero<D>();
821
+ D Cv13 = set_zero<D>();
822
+ D Cv20 = set_zero<D>();
823
+ D Cv21 = set_zero<D>();
824
+ D Cv22 = set_zero<D>();
825
+ D Cv23 = set_zero<D>();
826
+ D Cv30 = set_zero<D>();
827
+ D Cv31 = set_zero<D>();
828
+ D Cv32 = set_zero<D>();
829
+ D Cv33 = set_zero<D>();
830
+ D Cv40 = set_zero<D>();
831
+ D Cv41 = set_zero<D>();
832
+ D Cv42 = set_zero<D>();
833
+ D Cv43 = set_zero<D>();
834
+
835
+ for (int64_t l = 0; l < k; l += vl) {
836
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
837
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
838
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
839
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
840
+ V Bv4 = load<V>(B + ldb * (jj + 4) + l);
841
+
842
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
843
+ Cv00 = madd(Av0, Bv0, Cv00);
844
+ Cv10 = madd(Av0, Bv1, Cv10);
845
+ Cv20 = madd(Av0, Bv2, Cv20);
846
+ Cv30 = madd(Av0, Bv3, Cv30);
847
+ Cv40 = madd(Av0, Bv4, Cv40);
848
+
849
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
850
+ Cv01 = madd(Av1, Bv0, Cv01);
851
+ Cv11 = madd(Av1, Bv1, Cv11);
852
+ Cv21 = madd(Av1, Bv2, Cv21);
853
+ Cv31 = madd(Av1, Bv3, Cv31);
854
+ Cv41 = madd(Av1, Bv4, Cv41);
855
+
856
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
857
+ Cv02 = madd(Av2, Bv0, Cv02);
858
+ Cv12 = madd(Av2, Bv1, Cv12);
859
+ Cv22 = madd(Av2, Bv2, Cv22);
860
+ Cv32 = madd(Av2, Bv3, Cv32);
861
+ Cv42 = madd(Av2, Bv4, Cv42);
862
+
863
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
864
+ Cv03 = madd(Av3, Bv0, Cv03);
865
+ Cv13 = madd(Av3, Bv1, Cv13);
866
+ Cv23 = madd(Av3, Bv2, Cv23);
867
+ Cv33 = madd(Av3, Bv3, Cv33);
868
+ Cv43 = madd(Av3, Bv4, Cv43);
869
+ }
870
+
871
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
872
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
873
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
874
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
875
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
876
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
877
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
878
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
879
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
880
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
881
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
882
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
883
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
884
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
885
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
886
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
887
+ C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
888
+ C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
889
+ C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
890
+ C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
891
+ }
892
+
893
+ inline void gemm_bloc_4x4(int64_t ii, int64_t jj) {
894
+ size_t vl = vlmax<V>();
895
+ D Cv00 = set_zero<D>();
896
+ D Cv01 = set_zero<D>();
897
+ D Cv02 = set_zero<D>();
898
+ D Cv03 = set_zero<D>();
899
+ D Cv10 = set_zero<D>();
900
+ D Cv11 = set_zero<D>();
901
+ D Cv12 = set_zero<D>();
902
+ D Cv13 = set_zero<D>();
903
+ D Cv20 = set_zero<D>();
904
+ D Cv21 = set_zero<D>();
905
+ D Cv22 = set_zero<D>();
906
+ D Cv23 = set_zero<D>();
907
+ D Cv30 = set_zero<D>();
908
+ D Cv31 = set_zero<D>();
909
+ D Cv32 = set_zero<D>();
910
+ D Cv33 = set_zero<D>();
911
+
912
+ for (int64_t l = 0; l < k; l += vl) {
913
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
914
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
915
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
916
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
917
+
918
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
919
+ Cv00 = madd(Av0, Bv0, Cv00);
920
+ Cv01 = madd(Av1, Bv0, Cv01);
921
+ Cv02 = madd(Av2, Bv0, Cv02);
922
+ Cv03 = madd(Av3, Bv0, Cv03);
923
+
924
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
925
+ Cv10 = madd(Av0, Bv1, Cv10);
926
+ Cv11 = madd(Av1, Bv1, Cv11);
927
+ Cv12 = madd(Av2, Bv1, Cv12);
928
+ Cv13 = madd(Av3, Bv1, Cv13);
929
+
930
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
931
+ Cv20 = madd(Av0, Bv2, Cv20);
932
+ Cv21 = madd(Av1, Bv2, Cv21);
933
+ Cv22 = madd(Av2, Bv2, Cv22);
934
+ Cv23 = madd(Av3, Bv2, Cv23);
935
+
936
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
937
+ Cv30 = madd(Av0, Bv3, Cv30);
938
+ Cv31 = madd(Av1, Bv3, Cv31);
939
+ Cv32 = madd(Av2, Bv3, Cv32);
940
+ Cv33 = madd(Av3, Bv3, Cv33);
941
+ }
942
+
943
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
944
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
945
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
946
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
947
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
948
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
949
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
950
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
951
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
952
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
953
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
954
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
955
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
956
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
957
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
958
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
959
+ }
960
+
961
+ inline void gemm_bloc_4x3(int64_t ii, int64_t jj) {
962
+ size_t vl = vlmax<V>();
963
+ D Cv00 = set_zero<D>();
964
+ D Cv01 = set_zero<D>();
965
+ D Cv02 = set_zero<D>();
966
+ D Cv03 = set_zero<D>();
967
+ D Cv10 = set_zero<D>();
968
+ D Cv11 = set_zero<D>();
969
+ D Cv12 = set_zero<D>();
970
+ D Cv13 = set_zero<D>();
971
+ D Cv20 = set_zero<D>();
972
+ D Cv21 = set_zero<D>();
973
+ D Cv22 = set_zero<D>();
974
+ D Cv23 = set_zero<D>();
975
+
976
+ for (int64_t l = 0; l < k; l += vl) {
977
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
978
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
979
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
980
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
981
+
982
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
983
+ Cv00 = madd(Av0, Bv0, Cv00);
984
+ Cv01 = madd(Av1, Bv0, Cv01);
985
+ Cv02 = madd(Av2, Bv0, Cv02);
986
+ Cv03 = madd(Av3, Bv0, Cv03);
987
+
988
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
989
+ Cv10 = madd(Av0, Bv1, Cv10);
990
+ Cv11 = madd(Av1, Bv1, Cv11);
991
+ Cv12 = madd(Av2, Bv1, Cv12);
992
+ Cv13 = madd(Av3, Bv1, Cv13);
993
+
994
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
995
+ Cv20 = madd(Av0, Bv2, Cv20);
996
+ Cv21 = madd(Av1, Bv2, Cv21);
997
+ Cv22 = madd(Av2, Bv2, Cv22);
998
+ Cv23 = madd(Av3, Bv2, Cv23);
999
+ }
1000
+
1001
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1002
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1003
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
1004
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
1005
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
1006
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
1007
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
1008
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
1009
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
1010
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
1011
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
1012
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
1013
+ }
1014
+
1015
+ inline void gemm_bloc_4x2(int64_t ii, int64_t jj) {
1016
+ size_t vl = vlmax<V>();
1017
+ D Cv00 = set_zero<D>();
1018
+ D Cv01 = set_zero<D>();
1019
+ D Cv02 = set_zero<D>();
1020
+ D Cv03 = set_zero<D>();
1021
+ D Cv10 = set_zero<D>();
1022
+ D Cv11 = set_zero<D>();
1023
+ D Cv12 = set_zero<D>();
1024
+ D Cv13 = set_zero<D>();
1025
+
1026
+ for (int64_t l = 0; l < k; l += vl) {
1027
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
1028
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
1029
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
1030
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
1031
+
1032
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1033
+ Cv00 = madd(Av0, Bv0, Cv00);
1034
+ Cv01 = madd(Av1, Bv0, Cv01);
1035
+ Cv02 = madd(Av2, Bv0, Cv02);
1036
+ Cv03 = madd(Av3, Bv0, Cv03);
1037
+
1038
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
1039
+ Cv10 = madd(Av0, Bv1, Cv10);
1040
+ Cv11 = madd(Av1, Bv1, Cv11);
1041
+ Cv12 = madd(Av2, Bv1, Cv12);
1042
+ Cv13 = madd(Av3, Bv1, Cv13);
1043
+ }
1044
+
1045
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1046
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1047
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
1048
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
1049
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
1050
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
1051
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
1052
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
1053
+ }
1054
+
1055
+ inline void gemm_bloc_4x1(int64_t ii, int64_t jj) {
1056
+ size_t vl = vlmax<V>();
1057
+ D Cv00 = set_zero<D>();
1058
+ D Cv01 = set_zero<D>();
1059
+ D Cv02 = set_zero<D>();
1060
+ D Cv03 = set_zero<D>();
1061
+
1062
+ for (int64_t l = 0; l < k; l += vl) {
1063
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
1064
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
1065
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
1066
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
1067
+
1068
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1069
+ Cv00 = madd(Av0, Bv0, Cv00);
1070
+ Cv01 = madd(Av1, Bv0, Cv01);
1071
+ Cv02 = madd(Av2, Bv0, Cv02);
1072
+ Cv03 = madd(Av3, Bv0, Cv03);
1073
+ }
1074
+
1075
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1076
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1077
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
1078
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
1079
+ }
1080
+
1081
+ inline void gemm_bloc_2x2(int64_t ii, int64_t jj) {
1082
+ size_t vl = vlmax<V>();
1083
+ D Cv00 = set_zero<D>();
1084
+ D Cv01 = set_zero<D>();
1085
+ D Cv10 = set_zero<D>();
1086
+ D Cv11 = set_zero<D>();
1087
+
1088
+ for (int64_t l = 0; l < k; l += vl) {
1089
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
1090
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
1091
+
1092
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1093
+ Cv00 = madd(Av0, Bv0, Cv00);
1094
+ Cv01 = madd(Av1, Bv0, Cv01);
1095
+
1096
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
1097
+ Cv10 = madd(Av0, Bv1, Cv10);
1098
+ Cv11 = madd(Av1, Bv1, Cv11);
1099
+ }
1100
+
1101
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1102
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1103
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
1104
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
1105
+ }
1106
+
1107
+ inline void gemm_bloc_2x1(int64_t ii, int64_t jj) {
1108
+ size_t vl = vlmax<V>();
1109
+ D Cv00 = set_zero<D>();
1110
+ D Cv01 = set_zero<D>();
1111
+
1112
+ for (int64_t l = 0; l < k; l += vl) {
1113
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
1114
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
1115
+
1116
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1117
+ Cv00 = madd(Av0, Bv0, Cv00);
1118
+ Cv01 = madd(Av1, Bv0, Cv01);
1119
+ }
1120
+
1121
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1122
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1123
+ }
1124
+
1125
+ template <int RM, int RN>
1126
+ inline void gemm_bloc(int64_t ii, int64_t jj) {
1127
+ if constexpr (RM == 4) {
1128
+ if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); }
1129
+ if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); }
1130
+ if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); }
1131
+ if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); }
1132
+ if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); }
1133
+ if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); }
1134
+ } else if constexpr (RM == 2) {
1135
+ if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); }
1136
+ if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); }
1137
+ }
1138
+ }
1139
+
1140
+ template <int RM, int RN, int BM>
1141
+ NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
1142
+ GGML_ASSERT(m % (RM * BM) == 0);
1143
+ const int64_t ytiles = m / (RM * BM);
1144
+ const int64_t xtiles = (n + RN -1) / RN;
1145
+ const int64_t jj_RN = (xtiles - (xtiles * RN - n));
1146
+
1147
+ // "round" bloc_size to "nearest" BN
1148
+ const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
1149
+ const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
1150
+ const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
1151
+ const int64_t nb_job = ytiles * NB_BN;
1152
+
1153
+ if (params->ith == 0) {
1154
+ GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
1155
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
1156
+ ggml_threadpool_chunk_set(params->threadpool, params->nth);
1157
+ }
1158
+
1159
+ ggml_barrier(params->threadpool);
1160
+
1161
+ int64_t job = params->ith;
1162
+ while (job < nb_job) {
1163
+ const int64_t ii = (job % ytiles) * RM * BM;
1164
+ const int64_t jb = job / ytiles;
1165
+ const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
1166
+ const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
1167
+
1168
+ const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
1169
+ const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
1170
+ const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
1171
+
1172
+ for (int64_t bi = 0; bi < BM * RM; bi += RM) {
1173
+ int64_t jj = jj0;
1174
+ for (; jj < jj1; jj += RN) {
1175
+ gemm_bloc<RM, RN>(ii + bi, jj);
1176
+ }
1177
+ if constexpr (RN > 1) {
1178
+ for (; jj < jj2; jj += RN - 1) {
1179
+ gemm_bloc<RM, RN-1>(ii + bi, jj);
1180
+ }
1181
+ }
1182
+ GGML_ASSERT(jj == jj2);
1183
+ }
1184
+
1185
+ job = ggml_threadpool_chunk_add(params->threadpool, 1);
1186
+ }
1187
+
1188
+ ggml_barrier(params->threadpool);
1189
+ return;
1190
+ }
1191
+
1192
+ const ggml_compute_params * params;
1193
+ const TA *const A;
1194
+ const TB *const B;
1195
+ TC *const C;
1196
+ const int64_t k;
1197
+ const int64_t lda;
1198
+ const int64_t ldb;
1199
+ const int64_t ldc;
1200
+ };
1201
+ #endif
1202
+
492
1203
  //////////////////////////////////////////////////////////////////////////////////////////
493
1204
  // QUANT ZERO MATRIX MULTIPLICATION
494
1205
 
@@ -1573,95 +2284,35 @@ class tinyBLAS_BF16_PPC {
1573
2284
  const int nth;
1574
2285
  };
1575
2286
 
1576
- template <typename TA>
1577
- class tinyBLAS_Q0_PPC {
1578
- public:
1579
- tinyBLAS_Q0_PPC(int64_t k,
1580
- const TA *A, int64_t lda,
1581
- const block_q8_0 *B, int64_t ldb,
1582
- float *C, int64_t ldc,
1583
- int ith, int nth)
2287
+ template <typename TA>
2288
+ tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
2289
+ const TA *A, int64_t lda,
2290
+ const block_q8_0 *B, int64_t ldb,
2291
+ float *C, int64_t ldc,
2292
+ int ith, int nth)
1584
2293
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
2294
+ kc = 64;
1585
2295
  }
1586
2296
 
1587
- void matmul(int64_t m, int64_t n) {
1588
- mnpack(0, m, 0, n);
1589
- }
1590
-
1591
- private:
1592
-
1593
- inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
1594
- for (int I = 0; I < RM; I++) {
1595
- for (int J = 0; J < RN; J++) {
1596
- *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
1597
- }
1598
- }
1599
- }
1600
-
1601
- template<int size>
1602
- inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
1603
- vector signed int vec_C[4];
1604
- vector float CA[4] = {0};
1605
- vector float res[4] = {0};
1606
- __builtin_mma_disassemble_acc(vec_C, ACC);
1607
- for (int i = 0; i < 4; i++) {
1608
- CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
1609
- res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
1610
- fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
1611
- }
1612
- }
1613
- /* This function processes quantized data from block_q4_0 elements.
1614
- * First the we try to extract the two int4 values stored in single int8_t into two signed int8.
1615
- * And then we subtract each of the resultant element with 8, to convert signed int8 to unsigned int8.
1616
- * Also compute the rowsum which is required to compensate the above conversion. */
1617
- inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
1618
- const vector signed char lowMask = vec_splats((signed char)0xF);
1619
- const vector unsigned char v4 = vec_splats((unsigned char)0x4);
1620
- const vector signed char v8 = vec_splats((signed char)0x8);
1621
- vector signed int vsum = {0};
1622
- vector signed int vsum2 = {0};
1623
- c[0] = vec_and(c[1], lowMask);
1624
- c[1] = vec_sr(c[1], v4);
1625
- c[0] = vec_sub(c[0], v8);
1626
- c[1] = vec_sub(c[1], v8);
1627
- vsum = vec_sum4s(c[0], vsum);
1628
- vsum2 = vec_sum4s(c[1], vsum2);
1629
- vsum = vec_add(vsum, vsum2);
1630
- *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1631
- }
1632
-
1633
- template <typename V1, typename V2>
1634
- inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
1635
- vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
1636
- vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
1637
- vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
1638
- vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
1639
- V2 t1, t2, t3, t4, t5, t6, t7, t8;
1640
- vector unsigned char xor_vector;
1641
- uint8_t flip_vec = 0x80;
1642
- xor_vector = vec_splats(flip_vec);
1643
- t1 = vec_perm(s1, s2, swiz1);
1644
- t2 = vec_perm(s1, s2, swiz2);
1645
- t3 = vec_perm(s3, s4, swiz1);
1646
- t4 = vec_perm(s3, s4, swiz2);
1647
- t5 = vec_perm(t1, t3, swiz3);
1648
- t6 = vec_perm(t1, t3, swiz4);
1649
- t7 = vec_perm(t2, t4, swiz3);
1650
- t8 = vec_perm(t2, t4, swiz4);
1651
- if (flip == true) {
1652
- t5 = vec_xor(t5, xor_vector);
1653
- t6 = vec_xor(t6, xor_vector);
1654
- t7 = vec_xor(t7, xor_vector);
1655
- t8 = vec_xor(t8, xor_vector);
2297
+ template<typename TA>
2298
+ void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
2299
+ int mc = 64; int nc = 64;
2300
+ if (n % 8 == 0 && n < nc) {
2301
+ nc = n;
2302
+ mc = 32 ;
2303
+ kc = 32;
2304
+ }
2305
+ const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
2306
+ if (is_aligned) {
2307
+ this->matmul_tiled_q0(m, n, mc, nc, kc);
2308
+ } else {
2309
+ mnpack(0, m, 0, n);
1656
2310
  }
1657
- vec_xst(t5, 0, vecOffset);
1658
- vec_xst(t6, 0, vecOffset+16);
1659
- vec_xst(t7, 0, vecOffset+32);
1660
- vec_xst(t8, 0, vecOffset+48);
1661
2311
  }
1662
2312
 
1663
- template<int size>
1664
- void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
2313
+ template<typename TA>
2314
+ template<int size>
2315
+ void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
1665
2316
  int64_t i, j;
1666
2317
  TA *aoffset = NULL;
1667
2318
  int8_t *vecOffset = NULL;
@@ -1781,8 +2432,10 @@ class tinyBLAS_Q0_PPC {
1781
2432
  }
1782
2433
  }
1783
2434
  }
2435
+
2436
+ template<typename TA>
1784
2437
  template<typename VA, typename VB>
1785
- void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
2438
+ void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
1786
2439
  int64_t i, j;
1787
2440
  block_q8_0 *aoffset = NULL;
1788
2441
  VA *vecOffset = NULL;
@@ -1822,7 +2475,6 @@ class tinyBLAS_Q0_PPC {
1822
2475
  j--;
1823
2476
  } while(j > 0);
1824
2477
  }
1825
-
1826
2478
  if (rows & 4) {
1827
2479
  aoffsets[0] = aoffset;
1828
2480
  for (int it = 1; it < 4; it++ )
@@ -1878,7 +2530,8 @@ class tinyBLAS_Q0_PPC {
1878
2530
  }
1879
2531
  }
1880
2532
 
1881
- void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2533
+ template<typename TA>
2534
+ void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1882
2535
  int m_rem = MIN(m - m0, 16);
1883
2536
  int n_rem = MIN(n - n0, 16);
1884
2537
 
@@ -1915,7 +2568,8 @@ class tinyBLAS_Q0_PPC {
1915
2568
  }
1916
2569
 
1917
2570
 
1918
- void KERNEL_4x8(int64_t ii, int64_t jj) {
2571
+ template<typename TA>
2572
+ void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
1919
2573
  vec_t vec_A[8], vec_B[16] = {0};
1920
2574
  acc_t acc_0, acc_1;
1921
2575
  std::array<int, 4> comparray {};
@@ -1953,14 +2607,15 @@ class tinyBLAS_Q0_PPC {
1953
2607
  aoffset += lda;
1954
2608
  }
1955
2609
  }
1956
- compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
1957
- compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
2610
+ compute(&acc_0, 0, 0, comparray, vs, fin_res);
2611
+ compute(&acc_1, 0, 4, comparray, vs, fin_res);
1958
2612
  }
1959
2613
  save_res(ii, jj, 0, fin_res);
1960
2614
  save_res(ii, jj+4, 4, fin_res);
1961
2615
  }
1962
2616
 
1963
- void KERNEL_8x4(int64_t ii, int64_t jj) {
2617
+ template<typename TA>
2618
+ void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
1964
2619
  vec_t vec_A[16], vec_B[8] = {0};
1965
2620
  acc_t acc_0, acc_1;
1966
2621
  std::array<int, 8> comparray {};
@@ -1997,16 +2652,18 @@ class tinyBLAS_Q0_PPC {
1997
2652
  aoffset += lda;
1998
2653
  }
1999
2654
  }
2000
- compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
2001
- compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
2655
+ compute(&acc_0, 0, 0, comparray, vs, fin_res);
2656
+ compute(&acc_1, 4, 4, comparray, vs, fin_res);
2002
2657
  }
2003
2658
  save_res(ii, jj, 0, fin_res);
2004
2659
  save_res(ii+4, jj, 4, fin_res);
2005
2660
  }
2006
2661
 
2007
- void KERNEL_8x8(int64_t ii, int64_t jj) {
2662
+ template<typename TA>
2663
+ void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
2008
2664
  vec_t vec_A[16], vec_B[16] = {0};
2009
2665
  acc_t acc_0, acc_1, acc_2, acc_3;
2666
+ acc_t acc_4, acc_5, acc_6, acc_7;
2010
2667
  std::array<int, 8> comparray {};
2011
2668
  vector float fin_res[16] = {0};
2012
2669
  vector float vs[16] = {0};
@@ -2046,10 +2703,10 @@ class tinyBLAS_Q0_PPC {
2046
2703
  aoffset += lda;
2047
2704
  }
2048
2705
  }
2049
- compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
2050
- compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
2051
- compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
2052
- compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
2706
+ compute(&acc_0, 0, 0, comparray, vs, fin_res);
2707
+ compute(&acc_1, 4, 4, comparray, vs, fin_res);
2708
+ compute(&acc_2, 0, 8, comparray, vs, fin_res);
2709
+ compute(&acc_3, 4, 12, comparray, vs, fin_res);
2053
2710
  }
2054
2711
  save_res(ii, jj, 0, fin_res);
2055
2712
  save_res(ii+4, jj, 4, fin_res);
@@ -2057,7 +2714,8 @@ class tinyBLAS_Q0_PPC {
2057
2714
  save_res(ii+4, jj+4, 12, fin_res);
2058
2715
  }
2059
2716
 
2060
- void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
2717
+ template<typename TA>
2718
+ void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
2061
2719
  int64_t ytiles = (m - m0) / RM;
2062
2720
  int64_t xtiles = (n - n0) / RN;
2063
2721
  int64_t tiles = xtiles * ytiles;
@@ -2125,21 +2783,9 @@ class tinyBLAS_Q0_PPC {
2125
2783
  }
2126
2784
  }
2127
2785
 
2128
- template<int RM, int RN>
2129
- inline void kernel(int64_t ii, int64_t jj) {
2130
- if constexpr(RM == 4 && RN == 8) {
2131
- KERNEL_4x8(ii,jj);
2132
- } else if constexpr(RM == 8 && RN == 4) {
2133
- KERNEL_8x4(ii,jj);
2134
- } else if constexpr(RM == 8 && RN == 8) {
2135
- KERNEL_8x8(ii,jj);
2136
- } else {
2137
- assert(false && "RN/RM values not supported");
2138
- }
2139
- }
2140
-
2786
+ template<typename TA>
2141
2787
  template <int RM, int RN>
2142
- NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2788
+ NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2143
2789
  int64_t ytiles = (m - m0) / RM;
2144
2790
  int64_t xtiles = (n - n0) / RN;
2145
2791
  int64_t tiles = xtiles * ytiles;
@@ -2151,20 +2797,12 @@ class tinyBLAS_Q0_PPC {
2151
2797
  for (int64_t job = start; job < end; ++job) {
2152
2798
  int64_t ii = m0 + job / xtiles * RM;
2153
2799
  int64_t jj = n0 + job % xtiles * RN;
2154
- kernel<RM, RN>(ii, jj);
2800
+ this->kernel<RM, RN>(ii, jj);
2155
2801
  }
2156
2802
  }
2157
2803
 
2158
- const TA *const A;
2159
- const block_q8_0 *const B;
2160
- float *C;
2161
- const int64_t k;
2162
- const int64_t lda;
2163
- const int64_t ldb;
2164
- const int64_t ldc;
2165
- const int ith;
2166
- const int nth;
2167
- };
2804
+ template class tinyBLAS_Q0_PPC<block_q4_0>;
2805
+ template class tinyBLAS_Q0_PPC<block_q8_0>;
2168
2806
 
2169
2807
  class tinyBLAS_PPC {
2170
2808
  public:
@@ -2731,6 +3369,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
2731
3369
  params->ith, params->nth};
2732
3370
  tb.matmul(m, n);
2733
3371
  return true;
3372
+ #elif defined(__riscv_zvfh)
3373
+ #if LMUL == 1
3374
+ tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
3375
+ k, (const float *)A, lda,
3376
+ (const float *)B, ldb,
3377
+ (float *)C, ldc};
3378
+ #elif LMUL == 2
3379
+ tinyBLAS_RVV<vfloat32m2_t, vfloat32m2_t, float, float, float> tb{ params,
3380
+ k, (const float *)A, lda,
3381
+ (const float *)B, ldb,
3382
+ (float *)C, ldc};
3383
+ #else // LMUL = 4
3384
+ tinyBLAS_RVV<vfloat32m4_t, vfloat32m4_t, float, float, float> tb{ params,
3385
+ k, (const float *)A, lda,
3386
+ (const float *)B, ldb,
3387
+ (float *)C, ldc};
3388
+ #endif
3389
+ return tb.matmul(m, n);
2734
3390
  #else
2735
3391
  return false;
2736
3392
  #endif
@@ -2773,6 +3429,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
2773
3429
  tb.matmul(m, n);
2774
3430
  return true;
2775
3431
  }
3432
+ #elif defined(__riscv_zvfbfwma)
3433
+ #if LMUL == 1
3434
+ tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
3435
+ k, (const ggml_bf16_t *)A, lda,
3436
+ (const ggml_bf16_t *)B, ldb,
3437
+ (float *)C, ldc};
3438
+ #elif LMUL == 2
3439
+ tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
3440
+ k, (const ggml_bf16_t *)A, lda,
3441
+ (const ggml_bf16_t *)B, ldb,
3442
+ (float *)C, ldc};
3443
+ #else // LMUL = 4
3444
+ tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
3445
+ k, (const ggml_bf16_t *)A, lda,
3446
+ (const ggml_bf16_t *)B, ldb,
3447
+ (float *)C, ldc};
3448
+ #endif
3449
+ return tb.matmul(m, n);
2776
3450
  #endif
2777
3451
  return false;
2778
3452
  }
@@ -2822,6 +3496,26 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
2822
3496
  (float *)C, ldc};
2823
3497
  return tb.matmul(m, n);
2824
3498
  }
3499
+ #elif defined(__riscv_zvfh)
3500
+ if (Btype == GGML_TYPE_F16) {
3501
+ #if LMUL == 1
3502
+ tinyBLAS_RVV<vfloat32m1_t, vfloat16mf2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
3503
+ k, (const ggml_fp16_t *)A, lda,
3504
+ (const ggml_fp16_t *)B, ldb,
3505
+ (float *)C, ldc};
3506
+ #elif LMUL == 2
3507
+ tinyBLAS_RVV<vfloat32m2_t, vfloat16m1_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
3508
+ k, (const ggml_fp16_t *)A, lda,
3509
+ (const ggml_fp16_t *)B, ldb,
3510
+ (float *)C, ldc};
3511
+ #else // LMUL = 4
3512
+ tinyBLAS_RVV<vfloat32m4_t, vfloat16m2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
3513
+ k, (const ggml_fp16_t *)A, lda,
3514
+ (const ggml_fp16_t *)B, ldb,
3515
+ (float *)C, ldc};
3516
+ #endif
3517
+ return tb.matmul(m, n);
3518
+ }
2825
3519
  #endif
2826
3520
  return false;
2827
3521
  }