whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -69,6 +69,10 @@
69
69
  #define VECTOR_REGISTERS 16
70
70
  #endif
71
71
 
72
+ #if defined(__riscv_v_intrinsic)
73
+ #define LMUL 4
74
+ #endif
75
+
72
76
  #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
73
77
 
74
78
  namespace {
@@ -176,6 +180,46 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
176
180
  }
177
181
  #endif
178
182
 
183
+ #if defined(__riscv_zvfh)
184
+ template <>
185
+ inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
186
+ return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
187
+ }
188
+ inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
189
+ return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
190
+ }
191
+ inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
192
+ return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
193
+ }
194
+ inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
195
+ return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
196
+ }
197
+ inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
198
+ return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
199
+ }
200
+ inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
201
+ return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
202
+ }
203
+ inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
204
+ return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
205
+ }
206
+ inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
207
+ return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
208
+ }
209
+ #endif
210
+
211
+ #if defined(__riscv_zvfbfwma)
212
+ inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
213
+ return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
214
+ }
215
+ inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
216
+ return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
217
+ }
218
+ inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
219
+ return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
220
+ }
221
+ #endif
222
+
179
223
  ////////////////////////////////////////////////////////////////////////////////////////////////////
180
224
  // VECTORIZED HORIZONTAL SUM
181
225
 
@@ -228,6 +272,25 @@ inline float hsum(__m512 x) {
228
272
  }
229
273
  #endif // __AVX512F__
230
274
 
275
+ #if defined(__riscv_zvfh)
276
+ inline float hsum(vfloat32m1_t x) {
277
+ return __riscv_vfmv_f_s_f32m1_f32(
278
+ __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));
279
+ }
280
+ inline float hsum(vfloat32m2_t x) {
281
+ return __riscv_vfmv_f_s_f32m1_f32(
282
+ __riscv_vfredusum_vs_f32m2_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m2()));
283
+ }
284
+ inline float hsum(vfloat32m4_t x) {
285
+ return __riscv_vfmv_f_s_f32m1_f32(
286
+ __riscv_vfredusum_vs_f32m4_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m4()));
287
+ }
288
+ inline float hsum(vfloat32m8_t x) {
289
+ return __riscv_vfmv_f_s_f32m1_f32(
290
+ __riscv_vfredusum_vs_f32m8_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m8()));
291
+ }
292
+ #endif
293
+
231
294
  ////////////////////////////////////////////////////////////////////////////////////////////////////
232
295
  // VECTORIZED MEMORY LOADING
233
296
 
@@ -316,6 +379,88 @@ template <> inline __m256bh load(const float *p) {
316
379
  }
317
380
  #endif
318
381
 
382
+ #if defined(__riscv_zvfh)
383
+ template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
384
+ return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
385
+ }
386
+ template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
387
+ return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
388
+ }
389
+ template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
390
+ return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
391
+ }
392
+ template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
393
+ return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
394
+ }
395
+ template <> inline vfloat32m1_t load(const float *p) {
396
+ return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
397
+ }
398
+ template <> inline vfloat32m2_t load(const float *p) {
399
+ return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2());
400
+ }
401
+ template <> inline vfloat32m4_t load(const float *p) {
402
+ return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4());
403
+ }
404
+ template <> inline vfloat32m8_t load(const float *p) {
405
+ return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8());
406
+ }
407
+ #endif
408
+
409
+ #if defined(__riscv_zvfbfwma)
410
+ template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
411
+ return __riscv_vle16_v_bf16mf2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16mf2());
412
+ }
413
+ template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
414
+ return __riscv_vle16_v_bf16m1(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m1());
415
+ }
416
+ template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
417
+ return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());
418
+ }
419
+ #endif
420
+
421
+ #if defined(__riscv_zvfh)
422
+ template <typename T> T set_zero();
423
+
424
+ template <> inline vfloat16mf2_t set_zero() {
425
+ return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());
426
+ }
427
+ template <> inline vfloat16m1_t set_zero() {
428
+ return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());
429
+ }
430
+ template <> inline vfloat16m2_t set_zero() {
431
+ return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());
432
+ }
433
+ template <> inline vfloat16m4_t set_zero() {
434
+ return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
435
+ }
436
+ template <> inline vfloat32m1_t set_zero() {
437
+ return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
438
+ }
439
+ template <> inline vfloat32m2_t set_zero() {
440
+ return __riscv_vfmv_v_f_f32m2(0, __riscv_vsetvlmax_e32m2());
441
+ }
442
+ template <> inline vfloat32m4_t set_zero() {
443
+ return __riscv_vfmv_v_f_f32m4(0, __riscv_vsetvlmax_e32m4());
444
+ }
445
+ template <> inline vfloat32m8_t set_zero() {
446
+ return __riscv_vfmv_v_f_f32m8(0, __riscv_vsetvlmax_e32m8());
447
+ }
448
+ #endif
449
+
450
+ #if defined(__riscv_v_intrinsic)
451
+ template <typename T> size_t vlmax() {
452
+ if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
453
+ else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
454
+ else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
455
+ else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
456
+ else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
457
+ else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
458
+ else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
459
+ else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
460
+ return 0;
461
+ }
462
+ #endif
463
+
319
464
  ////////////////////////////////////////////////////////////////////////////////////////////////////
320
465
  // FLOATING POINT MATRIX MULTIPLICATION
321
466
 
@@ -388,7 +533,7 @@ class tinyBLAS {
388
533
  if constexpr (RN > 1) {
389
534
  return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
390
535
  } else {
391
- GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
536
+ GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N);
392
537
  GGML_ASSERT(false); // we have missed something.
393
538
  }
394
539
  }
@@ -489,6 +634,573 @@ class tinyBLAS {
489
634
  const int64_t ldc;
490
635
  };
491
636
 
637
+ #if defined(__riscv_v_intrinsic)
638
+ template <typename D, typename V, typename TA, typename TB, typename TC>
639
+ class tinyBLAS_RVV {
640
+ public:
641
+ tinyBLAS_RVV(const ggml_compute_params * params, int64_t k,
642
+ const TA *A, int64_t lda,
643
+ const TB *B, int64_t ldb,
644
+ TC *C, int64_t ldc)
645
+ : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
646
+ }
647
+
648
+ bool matmul(int64_t m, int64_t n) {
649
+ if (k % vlmax<V>() != 0) {
650
+ return false;
651
+ }
652
+
653
+ #if LMUL == 1
654
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
655
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
656
+ mnpack<4, 6, 4>(m, n, SIZE_N, 12);
657
+ return true;
658
+ }
659
+ if (m % 8 == 0 ) {
660
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
661
+ mnpack<4, 6, 2>(m, n, SIZE_N, 12);
662
+ return true;
663
+ }
664
+ if (m % 4 == 0) {
665
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
666
+ mnpack<4, 6, 1>(m, n, SIZE_N, 12);
667
+ return true;
668
+ }
669
+ #elif LMUL == 2
670
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
671
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
672
+ mnpack<4, 3, 4>(m, n, SIZE_N, 24);
673
+ return true;
674
+ }
675
+ if (m % 8 == 0 ) {
676
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
677
+ mnpack<4, 3, 2>(m, n, SIZE_N, 24);
678
+ return true;
679
+ }
680
+ if (m % 4 == 0) {
681
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
682
+ mnpack<4, 3, 1>(m, n, SIZE_N, 24);
683
+ return true;
684
+ }
685
+ #else // LMUL = 4
686
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
687
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
688
+ mnpack<2, 2, 8>(m, n, SIZE_N, 36);
689
+ return true;
690
+ }
691
+ if (m % 8 == 0 ) {
692
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
693
+ mnpack<2, 2, 4>(m, n, SIZE_N, 36);
694
+ return true;
695
+ }
696
+ if (m % 4 == 0) {
697
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
698
+ mnpack<2, 2, 2>(m, n, SIZE_N, 36);
699
+ return true;
700
+ }
701
+ #endif
702
+ return false;
703
+ }
704
+
705
+ private:
706
+ template<int RM, int RN, int BM>
707
+ inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
708
+ if (SIZE_N == RN) {
709
+ return gemm<RM, RN, BM>(m, n, BN);
710
+ }
711
+ if constexpr (RN > 1) {
712
+ return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
713
+ } else {
714
+ GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N);
715
+ GGML_ASSERT(false); // we have missed something.
716
+ }
717
+ }
718
+
719
+ inline void gemm_bloc_4x6(int64_t ii, int64_t jj) {
720
+ size_t vl = vlmax<V>();
721
+ D Cv00 = set_zero<D>();
722
+ D Cv01 = set_zero<D>();
723
+ D Cv02 = set_zero<D>();
724
+ D Cv03 = set_zero<D>();
725
+ D Cv10 = set_zero<D>();
726
+ D Cv11 = set_zero<D>();
727
+ D Cv12 = set_zero<D>();
728
+ D Cv13 = set_zero<D>();
729
+ D Cv20 = set_zero<D>();
730
+ D Cv21 = set_zero<D>();
731
+ D Cv22 = set_zero<D>();
732
+ D Cv23 = set_zero<D>();
733
+ D Cv30 = set_zero<D>();
734
+ D Cv31 = set_zero<D>();
735
+ D Cv32 = set_zero<D>();
736
+ D Cv33 = set_zero<D>();
737
+ D Cv40 = set_zero<D>();
738
+ D Cv41 = set_zero<D>();
739
+ D Cv42 = set_zero<D>();
740
+ D Cv43 = set_zero<D>();
741
+ D Cv50 = set_zero<D>();
742
+ D Cv51 = set_zero<D>();
743
+ D Cv52 = set_zero<D>();
744
+ D Cv53 = set_zero<D>();
745
+
746
+ for (int64_t l = 0; l < k; l += vl) {
747
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
748
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
749
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
750
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
751
+ V Bv4 = load<V>(B + ldb * (jj + 4) + l);
752
+ V Bv5 = load<V>(B + ldb * (jj + 5) + l);
753
+
754
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
755
+ Cv00 = madd(Av0, Bv0, Cv00);
756
+ Cv10 = madd(Av0, Bv1, Cv10);
757
+ Cv20 = madd(Av0, Bv2, Cv20);
758
+ Cv30 = madd(Av0, Bv3, Cv30);
759
+ Cv40 = madd(Av0, Bv4, Cv40);
760
+ Cv50 = madd(Av0, Bv5, Cv50);
761
+
762
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
763
+ Cv01 = madd(Av1, Bv0, Cv01);
764
+ Cv11 = madd(Av1, Bv1, Cv11);
765
+ Cv21 = madd(Av1, Bv2, Cv21);
766
+ Cv31 = madd(Av1, Bv3, Cv31);
767
+ Cv41 = madd(Av1, Bv4, Cv41);
768
+ Cv51 = madd(Av1, Bv5, Cv51);
769
+
770
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
771
+ Cv02 = madd(Av2, Bv0, Cv02);
772
+ Cv12 = madd(Av2, Bv1, Cv12);
773
+ Cv22 = madd(Av2, Bv2, Cv22);
774
+ Cv32 = madd(Av2, Bv3, Cv32);
775
+ Cv42 = madd(Av2, Bv4, Cv42);
776
+ Cv52 = madd(Av2, Bv5, Cv52);
777
+
778
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
779
+ Cv03 = madd(Av3, Bv0, Cv03);
780
+ Cv13 = madd(Av3, Bv1, Cv13);
781
+ Cv23 = madd(Av3, Bv2, Cv23);
782
+ Cv33 = madd(Av3, Bv3, Cv33);
783
+ Cv43 = madd(Av3, Bv4, Cv43);
784
+ Cv53 = madd(Av3, Bv5, Cv53);
785
+ }
786
+
787
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
788
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
789
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
790
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
791
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
792
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
793
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
794
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
795
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
796
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
797
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
798
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
799
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
800
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
801
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
802
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
803
+ C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
804
+ C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
805
+ C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
806
+ C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
807
+ C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50);
808
+ C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51);
809
+ C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52);
810
+ C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53);
811
+ }
812
+
813
+ inline void gemm_bloc_4x5(int64_t ii, int64_t jj) {
814
+ size_t vl = vlmax<V>();
815
+ D Cv00 = set_zero<D>();
816
+ D Cv01 = set_zero<D>();
817
+ D Cv02 = set_zero<D>();
818
+ D Cv03 = set_zero<D>();
819
+ D Cv10 = set_zero<D>();
820
+ D Cv11 = set_zero<D>();
821
+ D Cv12 = set_zero<D>();
822
+ D Cv13 = set_zero<D>();
823
+ D Cv20 = set_zero<D>();
824
+ D Cv21 = set_zero<D>();
825
+ D Cv22 = set_zero<D>();
826
+ D Cv23 = set_zero<D>();
827
+ D Cv30 = set_zero<D>();
828
+ D Cv31 = set_zero<D>();
829
+ D Cv32 = set_zero<D>();
830
+ D Cv33 = set_zero<D>();
831
+ D Cv40 = set_zero<D>();
832
+ D Cv41 = set_zero<D>();
833
+ D Cv42 = set_zero<D>();
834
+ D Cv43 = set_zero<D>();
835
+
836
+ for (int64_t l = 0; l < k; l += vl) {
837
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
838
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
839
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
840
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
841
+ V Bv4 = load<V>(B + ldb * (jj + 4) + l);
842
+
843
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
844
+ Cv00 = madd(Av0, Bv0, Cv00);
845
+ Cv10 = madd(Av0, Bv1, Cv10);
846
+ Cv20 = madd(Av0, Bv2, Cv20);
847
+ Cv30 = madd(Av0, Bv3, Cv30);
848
+ Cv40 = madd(Av0, Bv4, Cv40);
849
+
850
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
851
+ Cv01 = madd(Av1, Bv0, Cv01);
852
+ Cv11 = madd(Av1, Bv1, Cv11);
853
+ Cv21 = madd(Av1, Bv2, Cv21);
854
+ Cv31 = madd(Av1, Bv3, Cv31);
855
+ Cv41 = madd(Av1, Bv4, Cv41);
856
+
857
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
858
+ Cv02 = madd(Av2, Bv0, Cv02);
859
+ Cv12 = madd(Av2, Bv1, Cv12);
860
+ Cv22 = madd(Av2, Bv2, Cv22);
861
+ Cv32 = madd(Av2, Bv3, Cv32);
862
+ Cv42 = madd(Av2, Bv4, Cv42);
863
+
864
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
865
+ Cv03 = madd(Av3, Bv0, Cv03);
866
+ Cv13 = madd(Av3, Bv1, Cv13);
867
+ Cv23 = madd(Av3, Bv2, Cv23);
868
+ Cv33 = madd(Av3, Bv3, Cv33);
869
+ Cv43 = madd(Av3, Bv4, Cv43);
870
+ }
871
+
872
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
873
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
874
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
875
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
876
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
877
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
878
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
879
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
880
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
881
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
882
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
883
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
884
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
885
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
886
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
887
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
888
+ C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
889
+ C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
890
+ C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
891
+ C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
892
+ }
893
+
894
+ inline void gemm_bloc_4x4(int64_t ii, int64_t jj) {
895
+ size_t vl = vlmax<V>();
896
+ D Cv00 = set_zero<D>();
897
+ D Cv01 = set_zero<D>();
898
+ D Cv02 = set_zero<D>();
899
+ D Cv03 = set_zero<D>();
900
+ D Cv10 = set_zero<D>();
901
+ D Cv11 = set_zero<D>();
902
+ D Cv12 = set_zero<D>();
903
+ D Cv13 = set_zero<D>();
904
+ D Cv20 = set_zero<D>();
905
+ D Cv21 = set_zero<D>();
906
+ D Cv22 = set_zero<D>();
907
+ D Cv23 = set_zero<D>();
908
+ D Cv30 = set_zero<D>();
909
+ D Cv31 = set_zero<D>();
910
+ D Cv32 = set_zero<D>();
911
+ D Cv33 = set_zero<D>();
912
+
913
+ for (int64_t l = 0; l < k; l += vl) {
914
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
915
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
916
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
917
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
918
+
919
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
920
+ Cv00 = madd(Av0, Bv0, Cv00);
921
+ Cv01 = madd(Av1, Bv0, Cv01);
922
+ Cv02 = madd(Av2, Bv0, Cv02);
923
+ Cv03 = madd(Av3, Bv0, Cv03);
924
+
925
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
926
+ Cv10 = madd(Av0, Bv1, Cv10);
927
+ Cv11 = madd(Av1, Bv1, Cv11);
928
+ Cv12 = madd(Av2, Bv1, Cv12);
929
+ Cv13 = madd(Av3, Bv1, Cv13);
930
+
931
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
932
+ Cv20 = madd(Av0, Bv2, Cv20);
933
+ Cv21 = madd(Av1, Bv2, Cv21);
934
+ Cv22 = madd(Av2, Bv2, Cv22);
935
+ Cv23 = madd(Av3, Bv2, Cv23);
936
+
937
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
938
+ Cv30 = madd(Av0, Bv3, Cv30);
939
+ Cv31 = madd(Av1, Bv3, Cv31);
940
+ Cv32 = madd(Av2, Bv3, Cv32);
941
+ Cv33 = madd(Av3, Bv3, Cv33);
942
+ }
943
+
944
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
945
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
946
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
947
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
948
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
949
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
950
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
951
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
952
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
953
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
954
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
955
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
956
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
957
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
958
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
959
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
960
+ }
961
+
962
+ inline void gemm_bloc_4x3(int64_t ii, int64_t jj) {
963
+ size_t vl = vlmax<V>();
964
+ D Cv00 = set_zero<D>();
965
+ D Cv01 = set_zero<D>();
966
+ D Cv02 = set_zero<D>();
967
+ D Cv03 = set_zero<D>();
968
+ D Cv10 = set_zero<D>();
969
+ D Cv11 = set_zero<D>();
970
+ D Cv12 = set_zero<D>();
971
+ D Cv13 = set_zero<D>();
972
+ D Cv20 = set_zero<D>();
973
+ D Cv21 = set_zero<D>();
974
+ D Cv22 = set_zero<D>();
975
+ D Cv23 = set_zero<D>();
976
+
977
+ for (int64_t l = 0; l < k; l += vl) {
978
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
979
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
980
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
981
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
982
+
983
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
984
+ Cv00 = madd(Av0, Bv0, Cv00);
985
+ Cv01 = madd(Av1, Bv0, Cv01);
986
+ Cv02 = madd(Av2, Bv0, Cv02);
987
+ Cv03 = madd(Av3, Bv0, Cv03);
988
+
989
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
990
+ Cv10 = madd(Av0, Bv1, Cv10);
991
+ Cv11 = madd(Av1, Bv1, Cv11);
992
+ Cv12 = madd(Av2, Bv1, Cv12);
993
+ Cv13 = madd(Av3, Bv1, Cv13);
994
+
995
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
996
+ Cv20 = madd(Av0, Bv2, Cv20);
997
+ Cv21 = madd(Av1, Bv2, Cv21);
998
+ Cv22 = madd(Av2, Bv2, Cv22);
999
+ Cv23 = madd(Av3, Bv2, Cv23);
1000
+ }
1001
+
1002
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1003
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1004
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
1005
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
1006
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
1007
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
1008
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
1009
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
1010
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
1011
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
1012
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
1013
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
1014
+ }
1015
+
1016
+ inline void gemm_bloc_4x2(int64_t ii, int64_t jj) {
1017
+ size_t vl = vlmax<V>();
1018
+ D Cv00 = set_zero<D>();
1019
+ D Cv01 = set_zero<D>();
1020
+ D Cv02 = set_zero<D>();
1021
+ D Cv03 = set_zero<D>();
1022
+ D Cv10 = set_zero<D>();
1023
+ D Cv11 = set_zero<D>();
1024
+ D Cv12 = set_zero<D>();
1025
+ D Cv13 = set_zero<D>();
1026
+
1027
+ for (int64_t l = 0; l < k; l += vl) {
1028
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
1029
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
1030
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
1031
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
1032
+
1033
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1034
+ Cv00 = madd(Av0, Bv0, Cv00);
1035
+ Cv01 = madd(Av1, Bv0, Cv01);
1036
+ Cv02 = madd(Av2, Bv0, Cv02);
1037
+ Cv03 = madd(Av3, Bv0, Cv03);
1038
+
1039
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
1040
+ Cv10 = madd(Av0, Bv1, Cv10);
1041
+ Cv11 = madd(Av1, Bv1, Cv11);
1042
+ Cv12 = madd(Av2, Bv1, Cv12);
1043
+ Cv13 = madd(Av3, Bv1, Cv13);
1044
+ }
1045
+
1046
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1047
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1048
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
1049
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
1050
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
1051
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
1052
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
1053
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
1054
+ }
1055
+
1056
+ inline void gemm_bloc_4x1(int64_t ii, int64_t jj) {
1057
+ size_t vl = vlmax<V>();
1058
+ D Cv00 = set_zero<D>();
1059
+ D Cv01 = set_zero<D>();
1060
+ D Cv02 = set_zero<D>();
1061
+ D Cv03 = set_zero<D>();
1062
+
1063
+ for (int64_t l = 0; l < k; l += vl) {
1064
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
1065
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
1066
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
1067
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
1068
+
1069
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1070
+ Cv00 = madd(Av0, Bv0, Cv00);
1071
+ Cv01 = madd(Av1, Bv0, Cv01);
1072
+ Cv02 = madd(Av2, Bv0, Cv02);
1073
+ Cv03 = madd(Av3, Bv0, Cv03);
1074
+ }
1075
+
1076
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1077
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1078
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
1079
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
1080
+ }
1081
+
1082
+ inline void gemm_bloc_2x2(int64_t ii, int64_t jj) {
1083
+ size_t vl = vlmax<V>();
1084
+ D Cv00 = set_zero<D>();
1085
+ D Cv01 = set_zero<D>();
1086
+ D Cv10 = set_zero<D>();
1087
+ D Cv11 = set_zero<D>();
1088
+
1089
+ for (int64_t l = 0; l < k; l += vl) {
1090
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
1091
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
1092
+
1093
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1094
+ Cv00 = madd(Av0, Bv0, Cv00);
1095
+ Cv01 = madd(Av1, Bv0, Cv01);
1096
+
1097
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
1098
+ Cv10 = madd(Av0, Bv1, Cv10);
1099
+ Cv11 = madd(Av1, Bv1, Cv11);
1100
+ }
1101
+
1102
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1103
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1104
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
1105
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
1106
+ }
1107
+
1108
+ inline void gemm_bloc_2x1(int64_t ii, int64_t jj) {
1109
+ size_t vl = vlmax<V>();
1110
+ D Cv00 = set_zero<D>();
1111
+ D Cv01 = set_zero<D>();
1112
+
1113
+ for (int64_t l = 0; l < k; l += vl) {
1114
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
1115
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
1116
+
1117
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1118
+ Cv00 = madd(Av0, Bv0, Cv00);
1119
+ Cv01 = madd(Av1, Bv0, Cv01);
1120
+ }
1121
+
1122
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1123
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1124
+ }
1125
+
1126
+ template <int RM, int RN>
1127
+ inline void gemm_bloc(int64_t ii, int64_t jj) {
1128
+ if constexpr (RM == 4) {
1129
+ if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); }
1130
+ if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); }
1131
+ if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); }
1132
+ if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); }
1133
+ if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); }
1134
+ if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); }
1135
+ } else if constexpr (RM == 2) {
1136
+ if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); }
1137
+ if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); }
1138
+ }
1139
+ }
1140
+
1141
+ template <int RM, int RN, int BM>
1142
+ NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
1143
+ GGML_ASSERT(m % (RM * BM) == 0);
1144
+ const int64_t ytiles = m / (RM * BM);
1145
+ const int64_t xtiles = (n + RN -1) / RN;
1146
+ const int64_t jj_RN = (xtiles - (xtiles * RN - n));
1147
+
1148
+ // "round" bloc_size to "nearest" BN
1149
+ const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
1150
+ const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
1151
+ const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
1152
+ const int64_t nb_job = ytiles * NB_BN;
1153
+
1154
+ if (params->ith == 0) {
1155
+ GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
1156
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
1157
+ ggml_threadpool_chunk_set(params->threadpool, params->nth);
1158
+ }
1159
+
1160
+ ggml_barrier(params->threadpool);
1161
+
1162
+ int64_t job = params->ith;
1163
+ while (job < nb_job) {
1164
+ const int64_t ii = (job % ytiles) * RM * BM;
1165
+ const int64_t jb = job / ytiles;
1166
+ const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
1167
+ const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
1168
+
1169
+ const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
1170
+ const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
1171
+ const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
1172
+
1173
+ for (int64_t bi = 0; bi < BM * RM; bi += RM) {
1174
+ int64_t jj = jj0;
1175
+ for (; jj < jj1; jj += RN) {
1176
+ gemm_bloc<RM, RN>(ii + bi, jj);
1177
+ }
1178
+ if constexpr (RN > 1) {
1179
+ for (; jj < jj2; jj += RN - 1) {
1180
+ gemm_bloc<RM, RN-1>(ii + bi, jj);
1181
+ }
1182
+ }
1183
+ GGML_ASSERT(jj == jj2);
1184
+ }
1185
+
1186
+ job = ggml_threadpool_chunk_add(params->threadpool, 1);
1187
+ }
1188
+
1189
+ ggml_barrier(params->threadpool);
1190
+ return;
1191
+ }
1192
+
1193
+ const ggml_compute_params * params;
1194
+ const TA *const A;
1195
+ const TB *const B;
1196
+ TC *const C;
1197
+ const int64_t k;
1198
+ const int64_t lda;
1199
+ const int64_t ldb;
1200
+ const int64_t ldc;
1201
+ };
1202
+ #endif
1203
+
492
1204
  //////////////////////////////////////////////////////////////////////////////////////////
493
1205
  // QUANT ZERO MATRIX MULTIPLICATION
494
1206
 
@@ -1086,10 +1798,27 @@ class tinyBLAS_Q0_AVX {
1086
1798
  } \
1087
1799
  } \
1088
1800
 
1801
+ template<typename T>
1802
+ struct mma_instr;
1803
+
1804
+ template<>
1805
+ struct mma_instr<ggml_bf16_t> {
1806
+ static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
1807
+ __builtin_mma_xvbf16ger2pp(acc, a, b);
1808
+ }
1809
+ };
1810
+
1811
+ template<>
1812
+ struct mma_instr<ggml_fp16_t> {
1813
+ static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
1814
+ __builtin_mma_xvf16ger2pp(acc, a, b);
1815
+ }
1816
+ };
1817
+
1089
1818
  template <typename TA, typename TB, typename TC>
1090
- class tinyBLAS_BF16_PPC {
1819
+ class tinyBLAS_HP16_PPC {
1091
1820
  public:
1092
- tinyBLAS_BF16_PPC(int64_t k,
1821
+ tinyBLAS_HP16_PPC(int64_t k,
1093
1822
  const TA *A, int64_t lda,
1094
1823
  const TB *B, int64_t ldb,
1095
1824
  TC *C, int64_t ldc,
@@ -1407,8 +2136,8 @@ class tinyBLAS_BF16_PPC {
1407
2136
  packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
1408
2137
  packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
1409
2138
  for (int x = 0; x < 4; x++) {
1410
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
1411
- __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
2139
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
2140
+ mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
1412
2141
  }
1413
2142
  }
1414
2143
  SAVE_ACC(&acc_0, ii, jj);
@@ -1424,8 +2153,8 @@ class tinyBLAS_BF16_PPC {
1424
2153
  packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
1425
2154
  packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
1426
2155
  for (int x = 0; x < 4; x++) {
1427
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
1428
- __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
2156
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
2157
+ mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
1429
2158
  }
1430
2159
  }
1431
2160
  SAVE_ACC(&acc_0, ii, jj);
@@ -1444,10 +2173,10 @@ class tinyBLAS_BF16_PPC {
1444
2173
  packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
1445
2174
  packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
1446
2175
  for (int x = 0; x < 4; x++) {
1447
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
1448
- __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
1449
- __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
1450
- __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
2176
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
2177
+ mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
2178
+ mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
2179
+ mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
1451
2180
  }
1452
2181
  }
1453
2182
 
@@ -1478,7 +2207,7 @@ class tinyBLAS_BF16_PPC {
1478
2207
  packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
1479
2208
  packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
1480
2209
  for (int x = 0; x<2; x++) {
1481
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
2210
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
1482
2211
  }
1483
2212
  }
1484
2213
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
@@ -1513,8 +2242,8 @@ class tinyBLAS_BF16_PPC {
1513
2242
  packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
1514
2243
  packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
1515
2244
  for (int x = 0; x<4; x++) {
1516
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
1517
- __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
2245
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
2246
+ mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
1518
2247
  }
1519
2248
  }
1520
2249
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
@@ -1577,44 +2306,91 @@ template <typename TA>
1577
2306
  class tinyBLAS_Q0_PPC {
1578
2307
  public:
1579
2308
  tinyBLAS_Q0_PPC(int64_t k,
1580
- const TA *A, int64_t lda,
1581
- const block_q8_0 *B, int64_t ldb,
1582
- float *C, int64_t ldc,
1583
- int ith, int nth)
2309
+ const TA * A, int64_t lda,
2310
+ const block_q8_0 * B, int64_t ldb,
2311
+ float * C, int64_t ldc,
2312
+ int ith, int nth)
1584
2313
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1585
2314
  }
1586
2315
 
1587
2316
  void matmul(int64_t m, int64_t n) {
1588
- mnpack(0, m, 0, n);
2317
+ const int64_t mc = 64;
2318
+ const int64_t kc = 64;
2319
+ int64_t nc = 64;
2320
+ int64_t n_aligned = 0;
2321
+ if (n % 64 == 0) {
2322
+ n_aligned = n;
2323
+ } else if (n == 4) {
2324
+ n_aligned = 4;
2325
+ } else if (n < 64) {
2326
+ n_aligned = (n / 8) * 8;
2327
+ } else {
2328
+ n_aligned = (n / 64) * 64;
2329
+ }
2330
+
2331
+ if (n_aligned > 0) {
2332
+ if (n_aligned % 64 == 0) nc = 64;
2333
+ else if (n_aligned == n) nc = n;
2334
+ else if (n_aligned % 32 == 0) nc = 32;
2335
+ else if (n_aligned % 24 == 0) nc = 24;
2336
+ else if (n_aligned % 16 == 0) nc = 16;
2337
+ else nc = 8;
2338
+ }
2339
+ bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
2340
+ if (can_use_tiled) {
2341
+ matmul_tiled(m, n_aligned, mc, nc, kc);
2342
+ if (n > n_aligned) {
2343
+ mnpack(0, m, n_aligned, n);
2344
+ }
2345
+ } else {
2346
+ mnpack(0, m, 0, n);
2347
+ }
1589
2348
  }
1590
2349
 
1591
2350
  private:
2351
+ inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {
2352
+ for (int I = 0; I < RM; I++) {
2353
+ for (int J = 0; J < RN; J++) {
2354
+ *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J);
2355
+ }
2356
+ }
2357
+ }
1592
2358
 
1593
- inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
1594
- for (int I = 0; I < RM; I++) {
1595
- for (int J = 0; J < RN; J++) {
1596
- *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
1597
- }
1598
- }
2359
+ inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
2360
+ vec_t vec_C[4];
2361
+ __builtin_mma_disassemble_acc(vec_C, ACC);
2362
+ for (int I = 0; I < 4; I++) {
2363
+ for (int J = 0; J < 4; J++) {
2364
+ *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J);
2365
+ }
2366
+ }
1599
2367
  }
1600
2368
 
1601
- template<int size>
1602
- inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
1603
- vector signed int vec_C[4];
1604
- vector float CA[4] = {0};
1605
- vector float res[4] = {0};
1606
- __builtin_mma_disassemble_acc(vec_C, ACC);
1607
- for (int i = 0; i < 4; i++) {
1608
- CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
1609
- res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
1610
- fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
1611
- }
2369
+ inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
2370
+ vec_t vec_C[4];
2371
+ __builtin_mma_disassemble_acc(vec_C, ACC);
2372
+ for (int I = 0; I < 4; I++) {
2373
+ for (int J = 0; J < 4; J++) {
2374
+ float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I);
2375
+ *c_ptr += *((float *)&vec_C[I] + J);
2376
+ }
2377
+ }
1612
2378
  }
1613
- /* This function processes quantized data from block_q4_0 elements.
1614
- * First the we try to extract the two int4 values stored in single int8_t into two signed int8.
1615
- * And then we subtract each of the resultant element with 8, to convert signed int8 to unsigned int8.
1616
- * Also compute the rowsum which is required to compensate the above conversion. */
1617
- inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
2379
+
2380
+ template<typename ArrayType>
2381
+ inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) {
2382
+ vector signed int vec_C[4];
2383
+ vector float CA[4] = {0};
2384
+ vector float res[4] = {0};
2385
+ __builtin_mma_disassemble_acc(vec_C, ACC);
2386
+ for (int i = 0; i < 4; i++) {
2387
+ CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));
2388
+ res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
2389
+ fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]);
2390
+ }
2391
+ }
2392
+
2393
+ inline void process_q4_elements(vector signed char (&c)[2], int * ca) {
1618
2394
  const vector signed char lowMask = vec_splats((signed char)0xF);
1619
2395
  const vector unsigned char v4 = vec_splats((unsigned char)0x4);
1620
2396
  const vector signed char v8 = vec_splats((signed char)0x8);
@@ -1631,7 +2407,7 @@ class tinyBLAS_Q0_PPC {
1631
2407
  }
1632
2408
 
1633
2409
  template <typename V1, typename V2>
1634
- inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
2410
+ inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {
1635
2411
  vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
1636
2412
  vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
1637
2413
  vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
@@ -1655,21 +2431,170 @@ class tinyBLAS_Q0_PPC {
1655
2431
  t8 = vec_xor(t8, xor_vector);
1656
2432
  }
1657
2433
  vec_xst(t5, 0, vecOffset);
1658
- vec_xst(t6, 0, vecOffset+16);
1659
- vec_xst(t7, 0, vecOffset+32);
1660
- vec_xst(t8, 0, vecOffset+48);
2434
+ vec_xst(t6, 0, vecOffset + 16);
2435
+ vec_xst(t7, 0, vecOffset + 32);
2436
+ vec_xst(t8, 0, vecOffset + 48);
2437
+ }
2438
+
2439
+ inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {
2440
+ const vector signed char lowMask = vec_splats((signed char)0x0F);
2441
+ const vector signed char v8 = vec_splats((signed char)0x08);
2442
+ const vector unsigned char v4 = vec_splats((unsigned char)4);
2443
+ lo = vec_and(packed, lowMask);
2444
+ hi = vec_sr(packed, v4);
2445
+ lo = vec_sub(lo, v8);
2446
+ hi = vec_sub(hi, v8);
2447
+ }
2448
+
2449
+ inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {
2450
+ vec_t t[8], s[8];
2451
+ vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
2452
+ vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
2453
+ vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
2454
+ vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
2455
+ for (int i = 0; i < 4; i += 2) {
2456
+ t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
2457
+ t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
2458
+ }
2459
+ for (int i = 4; i < 8; i += 2) {
2460
+ t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
2461
+ t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
2462
+ }
2463
+ s[0] = vec_perm(t[0], t[2], swiz3);
2464
+ s[1] = vec_perm(t[0], t[2], swiz4);
2465
+ s[2] = vec_perm(t[1], t[3], swiz3);
2466
+ s[3] = vec_perm(t[1], t[3], swiz4);
2467
+ s[4] = vec_perm(t[4], t[6], swiz3);
2468
+ s[5] = vec_perm(t[4], t[6], swiz4);
2469
+ s[6] = vec_perm(t[5], t[7], swiz3);
2470
+ s[7] = vec_perm(t[5], t[7], swiz4);
2471
+ for (int i = 0; i < 8; ++i) {
2472
+ vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16));
2473
+ }
2474
+ }
2475
+
2476
+ static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {
2477
+ vector signed short i16_hi = vec_unpackh(raw);
2478
+ vector signed short i16_lo = vec_unpackl(raw);
2479
+
2480
+ vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0);
2481
+ vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0);
2482
+ vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0);
2483
+ vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0);
2484
+ out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale));
2485
+ out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale));
2486
+ }
2487
+
2488
+ void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
2489
+ unsigned char * vecOffset = vec;
2490
+ for (int i = 0; i < rows; i += 8) {
2491
+ const block_q4_0 * rows_base[8];
2492
+ for (int r = 0; r < 8; r++) {
2493
+ rows_base[r] = a + (i + r) * lda;
2494
+ }
2495
+ for (int blk = 0; blk < blocks; blk++) {
2496
+ vector unsigned short hp_res[8][4];
2497
+ for (int r = 0; r < 8; r++) {
2498
+ const block_q4_0 * current_blk = rows_base[r] + blk;
2499
+ vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
2500
+ vector signed char v_qs = vec_xl(0, (const vector signed char *)current_blk->qs);
2501
+ vector signed char c1, c2;
2502
+ unpack_q4_to_q8(v_qs, c1, c2);
2503
+ convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
2504
+ convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);
2505
+ }
2506
+ for (int c = 0; c < 4; c++) {
2507
+ vector unsigned char c_arr[8];
2508
+ for (int r = 0; r < 8; r++) {
2509
+ c_arr[r] = (vector unsigned char)hp_res[r][c];
2510
+ }
2511
+ vector_permute_store_fp16((vec_t *)c_arr, vecOffset);
2512
+ vecOffset += 128;
2513
+ }
2514
+ }
2515
+ }
2516
+ }
2517
+
2518
+ template <int chunk_size>
2519
+ static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
2520
+ unsigned char * vecOffset = vec;
2521
+ const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
2522
+ const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
2523
+ const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
2524
+ const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
2525
+
2526
+ for (int i = 0; i < rows; i += chunk_size) {
2527
+ const block_q8_0 * rows_base[chunk_size];
2528
+ for (int r = 0; r < chunk_size; r++) {
2529
+ rows_base[r] = a + (i + r) * lda;
2530
+ }
2531
+ for (int blk = 0; blk < blocks; blk++) {
2532
+ vector unsigned short hp_res[chunk_size][4];
2533
+ for (int r = 0; r < chunk_size; r++) {
2534
+ const block_q8_0 * b = rows_base[r] + blk;
2535
+ vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));
2536
+ vector signed char c[2];
2537
+ __vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
2538
+ __builtin_vsx_disassemble_pair(c, & pair);
2539
+ convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);
2540
+ convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]);
2541
+ }
2542
+ for (int col = 0; col < 4; col++) {
2543
+ if constexpr (chunk_size == 8) {
2544
+ vec_t t[8];
2545
+ t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
2546
+ t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
2547
+ t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
2548
+ t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
2549
+ t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);
2550
+ t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);
2551
+ t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);
2552
+ t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);
2553
+
2554
+ vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));
2555
+ vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));
2556
+ vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));
2557
+ vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));
2558
+ vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));
2559
+ vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));
2560
+ vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));
2561
+ vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));
2562
+ vecOffset += 128;
2563
+ } else {
2564
+ vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
2565
+ vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
2566
+ vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
2567
+ vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
2568
+
2569
+ vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));
2570
+ vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));
2571
+ vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));
2572
+ vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));
2573
+ vecOffset += 64;
2574
+ }
2575
+ }
2576
+ }
2577
+ }
2578
+ }
2579
+
2580
+ void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
2581
+ if (rows == 4) {
2582
+ pack_q8_block<4>(a, lda, rows, blocks, vec);
2583
+ } else {
2584
+ pack_q8_block<8>(a, lda, rows, blocks, vec);
2585
+ }
1661
2586
  }
1662
2587
 
1663
2588
  template<int size>
1664
- void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
2589
+ void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {
1665
2590
  int64_t i, j;
1666
- TA *aoffset = NULL;
1667
- int8_t *vecOffset = NULL;
1668
- TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1669
- TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
2591
+ TA * aoffset = NULL;
2592
+ int8_t * vecOffset = NULL;
2593
+ TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;
2594
+ TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;
1670
2595
  vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
1671
2596
  vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
1672
- aoffset = const_cast<TA*>(a);
2597
+ aoffset = const_cast<TA *>(a);
1673
2598
  vecOffset = vec;
1674
2599
  j = (rows >> 3);
1675
2600
  if (j > 0) {
@@ -1686,27 +2611,27 @@ class tinyBLAS_Q0_PPC {
1686
2611
  i = (cols >> 2);
1687
2612
  if (i > 0) {
1688
2613
  do {
1689
- c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
1690
- c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
1691
- c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
1692
- c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
1693
- c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
1694
- c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
1695
- c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
1696
- c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
1697
-
1698
- process_q4_elements(c1, &comparray[0]);
1699
- process_q4_elements(c2, &comparray[1]);
1700
- process_q4_elements(c3, &comparray[2]);
1701
- process_q4_elements(c4, &comparray[3]);
1702
- process_q4_elements(c5, &comparray[4]);
1703
- process_q4_elements(c6, &comparray[5]);
1704
- process_q4_elements(c7, &comparray[6]);
1705
- process_q4_elements(c8, &comparray[7]);
2614
+ c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
2615
+ c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
2616
+ c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
2617
+ c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);
2618
+ c5[1] = vec_xl(0, (const vector signed char *)aoffset5->qs);
2619
+ c6[1] = vec_xl(0, (const vector signed char *)aoffset6->qs);
2620
+ c7[1] = vec_xl(0, (const vector signed char *)aoffset7->qs);
2621
+ c8[1] = vec_xl(0, (const vector signed char *)aoffset8->qs);
2622
+
2623
+ process_q4_elements(c1, & comparray[0]);
2624
+ process_q4_elements(c2, & comparray[1]);
2625
+ process_q4_elements(c3, & comparray[2]);
2626
+ process_q4_elements(c4, & comparray[3]);
2627
+ process_q4_elements(c5, & comparray[4]);
2628
+ process_q4_elements(c6, & comparray[5]);
2629
+ process_q4_elements(c7, & comparray[6]);
2630
+ process_q4_elements(c8, & comparray[7]);
1706
2631
  vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
1707
- vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
1708
- vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
1709
- vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
2632
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
2633
+ vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);
2634
+ vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);
1710
2635
  aoffset1 += lda;
1711
2636
  aoffset2 += lda;
1712
2637
  aoffset3 += lda;
@@ -1732,17 +2657,17 @@ class tinyBLAS_Q0_PPC {
1732
2657
  i = (cols >> 2);
1733
2658
  if (i > 0) {
1734
2659
  do {
1735
- c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
1736
- c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
1737
- c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
1738
- c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
1739
-
1740
- process_q4_elements(c1, &comparray[0]);
1741
- process_q4_elements(c2, &comparray[1]);
1742
- process_q4_elements(c3, &comparray[2]);
1743
- process_q4_elements(c4, &comparray[3]);
2660
+ c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
2661
+ c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
2662
+ c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
2663
+ c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);
2664
+
2665
+ process_q4_elements(c1, & comparray[0]);
2666
+ process_q4_elements(c2, & comparray[1]);
2667
+ process_q4_elements(c3, & comparray[2]);
2668
+ process_q4_elements(c4, & comparray[3]);
1744
2669
  vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
1745
- vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
2670
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
1746
2671
  aoffset1 += lda;
1747
2672
  aoffset2 += lda;
1748
2673
  aoffset3 += lda;
@@ -1761,17 +2686,17 @@ class tinyBLAS_Q0_PPC {
1761
2686
  if (i > 0) {
1762
2687
  do {
1763
2688
  switch(rows) {
1764
- case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
1765
- case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
1766
- case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
2689
+ case 3: c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
2690
+ case 2: c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
2691
+ case 1: c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
1767
2692
  break;
1768
2693
  }
1769
- process_q4_elements(c1, &comparray[0]);
1770
- process_q4_elements(c2, &comparray[1]);
1771
- process_q4_elements(c3, &comparray[2]);
1772
- process_q4_elements(c4, &comparray[3]);
2694
+ process_q4_elements(c1, & comparray[0]);
2695
+ process_q4_elements(c2, & comparray[1]);
2696
+ process_q4_elements(c3, & comparray[2]);
2697
+ process_q4_elements(c4, & comparray[3]);
1773
2698
  vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
1774
- vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
2699
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
1775
2700
  aoffset1 += lda;
1776
2701
  aoffset2 += lda;
1777
2702
  aoffset3 += lda;
@@ -1781,38 +2706,39 @@ class tinyBLAS_Q0_PPC {
1781
2706
  }
1782
2707
  }
1783
2708
  }
2709
+
1784
2710
  template<typename VA, typename VB>
1785
- void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
2711
+ void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {
1786
2712
  int64_t i, j;
1787
- block_q8_0 *aoffset = NULL;
1788
- VA *vecOffset = NULL;
1789
- block_q8_0* aoffsets[8];
2713
+ block_q8_0 * aoffset = NULL;
2714
+ VA * vecOffset = NULL;
2715
+ block_q8_0 * aoffsets[8];
1790
2716
  __vector_pair arr[8];
1791
2717
  VB c[8][2] = {0};
1792
2718
  VB c1[8] = {0}; VB c2[8] = {0};
1793
- aoffset = const_cast<block_q8_0*>(a);
2719
+ aoffset = const_cast<block_q8_0 *>(a);
1794
2720
  vecOffset = vec;
1795
2721
  j = (rows >> 3);
1796
2722
  if (j > 0) {
1797
2723
  do {
1798
2724
  aoffsets[0] = aoffset;
1799
2725
  for (int it = 1; it < 8; it++)
1800
- aoffsets[it] = aoffsets[it-1] + lda;
2726
+ aoffsets[it] = aoffsets[it - 1] + lda;
1801
2727
  aoffset += 8 * lda;
1802
2728
 
1803
2729
  i = (cols >> 3);
1804
2730
  if (i > 0) {
1805
2731
  do {
1806
2732
  for (int it = 0; it < 8; it++) {
1807
- arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
1808
- __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2733
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
2734
+ __builtin_vsx_disassemble_pair(c[it], & arr[it]);
1809
2735
  c1[it] = c[it][0];
1810
2736
  c2[it] = c[it][1];
1811
2737
  }
1812
2738
  vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
1813
- vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
1814
- vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
1815
- vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
2739
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
2740
+ vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);
2741
+ vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);
1816
2742
  for (int it = 0; it < 8; it++)
1817
2743
  aoffsets[it] += lda;
1818
2744
  vecOffset += 256;
@@ -1822,7 +2748,6 @@ class tinyBLAS_Q0_PPC {
1822
2748
  j--;
1823
2749
  } while(j > 0);
1824
2750
  }
1825
-
1826
2751
  if (rows & 4) {
1827
2752
  aoffsets[0] = aoffset;
1828
2753
  for (int it = 1; it < 4; it++ )
@@ -1832,13 +2757,13 @@ class tinyBLAS_Q0_PPC {
1832
2757
  if (i > 0) {
1833
2758
  do {
1834
2759
  for (int it = 0; it < 4; it++) {
1835
- arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
1836
- __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2760
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
2761
+ __builtin_vsx_disassemble_pair(c[it], & arr[it]);
1837
2762
  c1[it] = c[it][0];
1838
2763
  c2[it] = c[it][1];
1839
2764
  }
1840
2765
  vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
1841
- vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
2766
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
1842
2767
  for (int it = 0; it < 4; it++) {
1843
2768
  aoffsets[it] += lda;
1844
2769
  }
@@ -1851,24 +2776,24 @@ class tinyBLAS_Q0_PPC {
1851
2776
  if (rows & 3) {
1852
2777
  aoffsets[0] = aoffset;
1853
2778
  for (int it = 1; it < 3; it++ )
1854
- aoffsets[it] = aoffsets[it-1] + lda;
2779
+ aoffsets[it] = aoffsets[it - 1] + lda;
1855
2780
  i = (cols >> 3);
1856
2781
  if (i > 0) {
1857
2782
  do {
1858
2783
  switch(rows) {
1859
- case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
1860
- __builtin_vsx_disassemble_pair(c[2], &arr[2]);
2784
+ case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);
2785
+ __builtin_vsx_disassemble_pair(c[2], & arr[2]);
1861
2786
  c1[2] = c[2][0]; c2[2] = c[2][1];
1862
- case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
1863
- __builtin_vsx_disassemble_pair(c[1], &arr[1]);
2787
+ case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);
2788
+ __builtin_vsx_disassemble_pair(c[1], & arr[1]);
1864
2789
  c1[1] = c[1][0]; c2[1] = c[1][1];
1865
- case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
1866
- __builtin_vsx_disassemble_pair(c[0], &arr[0]);
2790
+ case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);
2791
+ __builtin_vsx_disassemble_pair(c[0], & arr[0]);
1867
2792
  c1[0] = c[0][0]; c2[0] = c[0][1];
1868
2793
  break;
1869
2794
  }
1870
2795
  vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
1871
- vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
2796
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
1872
2797
  for (int it = 0; it < 3; it++)
1873
2798
  aoffsets[it] += lda;
1874
2799
  vecOffset += 128;
@@ -1923,26 +2848,26 @@ class tinyBLAS_Q0_PPC {
1923
2848
  vector float vs[8] = {0};
1924
2849
  bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
1925
2850
  for (int l = 0; l < k; l++) {
1926
- __builtin_mma_xxsetaccz(&acc_0);
1927
- __builtin_mma_xxsetaccz(&acc_1);
2851
+ __builtin_mma_xxsetaccz(& acc_0);
2852
+ __builtin_mma_xxsetaccz(& acc_1);
1928
2853
  if (std::is_same_v<TA, block_q4_0>) {
1929
- packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
2854
+ packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);
1930
2855
  } else {
1931
- packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
2856
+ packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);
1932
2857
  }
1933
- packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
2858
+ packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
1934
2859
  for(int x = 0; x < 8; x++) {
1935
- __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
1936
- __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
2860
+ __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
2861
+ __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);
1937
2862
  }
1938
2863
  for (int I = 0; I<4; I++) {
1939
2864
  for (int J = 0; J<4; J++) {
1940
- *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
1941
- *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
2865
+ *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
2866
+ *((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
1942
2867
  }
1943
2868
  }
1944
2869
  if (!isAblock_q4) {
1945
- auto aoffset = A+(ii*lda)+l;
2870
+ auto aoffset = A + (ii * lda) + l;
1946
2871
  for (int i = 0; i < 4; i++) {
1947
2872
  comparray[i] = 0;
1948
2873
  int ca = 0;
@@ -1953,11 +2878,11 @@ class tinyBLAS_Q0_PPC {
1953
2878
  aoffset += lda;
1954
2879
  }
1955
2880
  }
1956
- compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
1957
- compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
2881
+ compute(& acc_0, 0, 0, comparray, vs, fin_res);
2882
+ compute(& acc_1, 0, 4, comparray, vs, fin_res);
1958
2883
  }
1959
2884
  save_res(ii, jj, 0, fin_res);
1960
- save_res(ii, jj+4, 4, fin_res);
2885
+ save_res(ii, jj + 4, 4, fin_res);
1961
2886
  }
1962
2887
 
1963
2888
  void KERNEL_8x4(int64_t ii, int64_t jj) {
@@ -1968,25 +2893,25 @@ class tinyBLAS_Q0_PPC {
1968
2893
  vector float vs[8] = {0};
1969
2894
  bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
1970
2895
  for (int l = 0; l < k; l++) {
1971
- __builtin_mma_xxsetaccz(&acc_0);
1972
- __builtin_mma_xxsetaccz(&acc_1);
2896
+ __builtin_mma_xxsetaccz(& acc_0);
2897
+ __builtin_mma_xxsetaccz(& acc_1);
1973
2898
  if (std::is_same_v<TA, block_q4_0>) {
1974
- packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2899
+ packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
1975
2900
  } else {
1976
- packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2901
+ packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
1977
2902
  }
1978
- packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
2903
+ packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);
1979
2904
  for(int x = 0; x < 8; x++) {
1980
- __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
1981
- __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
2905
+ __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
2906
+ __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
1982
2907
  }
1983
- for (int I = 0; I<8; I++) {
1984
- for (int J = 0; J<4; J++) {
1985
- *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
2908
+ for (int I = 0; I < 8; I++) {
2909
+ for (int J = 0; J < 4; J++) {
2910
+ *((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
1986
2911
  }
1987
2912
  }
1988
2913
  if (!isAblock_q4) {
1989
- auto aoffset = A+(ii*lda)+l;
2914
+ auto aoffset = A + (ii * lda) + l;
1990
2915
  for (int i = 0; i < 8; i++) {
1991
2916
  comparray[i] = 0;
1992
2917
  int ca = 0;
@@ -1997,45 +2922,46 @@ class tinyBLAS_Q0_PPC {
1997
2922
  aoffset += lda;
1998
2923
  }
1999
2924
  }
2000
- compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
2001
- compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
2925
+ compute(& acc_0, 0, 0, comparray, vs, fin_res);
2926
+ compute(& acc_1, 4, 4, comparray, vs, fin_res);
2002
2927
  }
2003
2928
  save_res(ii, jj, 0, fin_res);
2004
- save_res(ii+4, jj, 4, fin_res);
2929
+ save_res(ii + 4, jj, 4, fin_res);
2005
2930
  }
2006
2931
 
2007
2932
  void KERNEL_8x8(int64_t ii, int64_t jj) {
2008
2933
  vec_t vec_A[16], vec_B[16] = {0};
2009
2934
  acc_t acc_0, acc_1, acc_2, acc_3;
2935
+ acc_t acc_4, acc_5, acc_6, acc_7;
2010
2936
  std::array<int, 8> comparray {};
2011
2937
  vector float fin_res[16] = {0};
2012
2938
  vector float vs[16] = {0};
2013
2939
  bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2014
2940
  for (int l = 0; l < k; l++) {
2015
- __builtin_mma_xxsetaccz(&acc_0);
2016
- __builtin_mma_xxsetaccz(&acc_1);
2017
- __builtin_mma_xxsetaccz(&acc_2);
2018
- __builtin_mma_xxsetaccz(&acc_3);
2941
+ __builtin_mma_xxsetaccz(& acc_0);
2942
+ __builtin_mma_xxsetaccz(& acc_1);
2943
+ __builtin_mma_xxsetaccz(& acc_2);
2944
+ __builtin_mma_xxsetaccz(& acc_3);
2019
2945
  if (std::is_same_v<TA, block_q4_0>) {
2020
- packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2946
+ packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
2021
2947
  } else {
2022
- packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2948
+ packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
2023
2949
  }
2024
- packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
2950
+ packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
2025
2951
  for(int x = 0; x < 8; x++) {
2026
- __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
2027
- __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
2028
- __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
2029
- __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
2952
+ __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
2953
+ __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
2954
+ __builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);
2955
+ __builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);
2030
2956
  }
2031
- for (int I = 0; I<8; I++) {
2032
- for (int J = 0; J<4; J++) {
2033
- *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
2034
- *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
2957
+ for (int I = 0; I < 8 ; I++) {
2958
+ for (int J = 0; J < 4; J++) {
2959
+ *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
2960
+ *((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
2035
2961
  }
2036
2962
  }
2037
2963
  if (!isAblock_q4) {
2038
- auto aoffset = A+(ii*lda)+l;
2964
+ auto aoffset = A + (ii * lda) + l;
2039
2965
  for (int i = 0; i < 8; i++) {
2040
2966
  comparray[i] = 0;
2041
2967
  int ca = 0;
@@ -2046,15 +2972,96 @@ class tinyBLAS_Q0_PPC {
2046
2972
  aoffset += lda;
2047
2973
  }
2048
2974
  }
2049
- compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
2050
- compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
2051
- compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
2052
- compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
2975
+ compute(& acc_0, 0, 0, comparray, vs, fin_res);
2976
+ compute(& acc_1, 4, 4, comparray, vs, fin_res);
2977
+ compute(& acc_2, 0, 8, comparray, vs, fin_res);
2978
+ compute(& acc_3, 4, 12, comparray, vs, fin_res);
2053
2979
  }
2054
2980
  save_res(ii, jj, 0, fin_res);
2055
- save_res(ii+4, jj, 4, fin_res);
2056
- save_res(ii, jj+4, 8, fin_res);
2057
- save_res(ii+4, jj+4, 12, fin_res);
2981
+ save_res(ii + 4, jj, 4, fin_res);
2982
+ save_res(ii, jj + 4, 8, fin_res);
2983
+ save_res(ii + 4, jj + 4, 12, fin_res);
2984
+ }
2985
+
2986
+ void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {
2987
+ acc_t acc[8];
2988
+ for (int i = 0; i < mc ; i += 16) {
2989
+ for (int j = 0; j < nc; j += 8) {
2990
+ int A0_base = (i / 16) * (2 * 32 * kc);
2991
+ int B0_base = (j / 8) * (32 * kc);
2992
+ for (int x = 0; x < 8; x++) {
2993
+ __builtin_mma_xxsetaccz(&acc[x]);
2994
+ }
2995
+ for (int64_t kk = 0; kk < kc; kk++) {
2996
+ int A0_block_idx = A0_base + kk * 32;
2997
+ int B0_block_idx = B0_base + kk * 32;
2998
+ int A1_block_idx = A0_block_idx + 32 * kc;
2999
+ int B1_block_idx = B0_block_idx + 32 * kc;
3000
+ vec_t * A0_block = & vec_A[A0_block_idx];
3001
+ vec_t * B0_block = & vec_B[B0_block_idx];
3002
+ vec_t * A1_block = & vec_A[A1_block_idx];
3003
+ for (int it = 0; it < 4; it++) {
3004
+ for (int x = 0; x < 4; x++) {
3005
+ __builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);
3006
+ __builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);
3007
+ __builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);
3008
+ __builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
3009
+ __builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);
3010
+ __builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);
3011
+ __builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);
3012
+ __builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
3013
+ }
3014
+ }
3015
+ }
3016
+ if (l == 0) {
3017
+ save_acc(& acc[0], ii + i, jj + j);
3018
+ save_acc(& acc[1], ii + i, jj + j + 4);
3019
+ save_acc(& acc[2], ii + i + 4, jj + j);
3020
+ save_acc(& acc[3], ii + i + 4, jj + j + 4);
3021
+ save_acc(& acc[4], ii + i + 8, jj + j);
3022
+ save_acc(& acc[5], ii + i + 8, jj + j + 4);
3023
+ save_acc(& acc[6], ii + i + 12, jj + j);
3024
+ save_acc(& acc[7], ii + i + 12, jj + j + 4);
3025
+ } else {
3026
+ add_save_acc(& acc[0], ii + i, jj + j);
3027
+ add_save_acc(& acc[1], ii + i, jj + j + 4);
3028
+ add_save_acc(& acc[2], ii + i + 4, jj + j);
3029
+ add_save_acc(& acc[3], ii + i + 4, jj + j + 4);
3030
+ add_save_acc(& acc[4], ii + i + 8, jj + j);
3031
+ add_save_acc(& acc[5], ii + i + 8, jj + j + 4);
3032
+ add_save_acc(& acc[6], ii + i + 12, jj + j);
3033
+ add_save_acc(& acc[7], ii + i + 12, jj + j + 4);
3034
+ }
3035
+ }
3036
+ }
3037
+ }
3038
+
3039
+ void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
3040
+ vec_t A_pack[mc * kc * 4];
3041
+ vec_t B_pack[nc * kc * 4];
3042
+ constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
3043
+ int64_t ytiles = m / mc;
3044
+ int64_t xtiles = n / nc;
3045
+ int64_t tiles = xtiles * ytiles;
3046
+ int64_t duty = (tiles + nth - 1) / nth;
3047
+ int64_t start = duty * ith;
3048
+ int64_t end = start + duty;
3049
+ if (end > tiles) {
3050
+ end = tiles;
3051
+ }
3052
+ for (int64_t job = start; job < end; ++job) {
3053
+ int64_t ii = (job / xtiles) * mc;
3054
+ int64_t jj = (job % xtiles) * nc;
3055
+ for (int64_t kk = 0; kk < k; kk += kc) {
3056
+ if constexpr(is_Ablock_q4) {
3057
+ packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
3058
+ } else {
3059
+ packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
3060
+ }
3061
+ packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
3062
+ KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
3063
+ }
3064
+ }
2058
3065
  }
2059
3066
 
2060
3067
  void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
@@ -2079,32 +3086,32 @@ class tinyBLAS_Q0_PPC {
2079
3086
  vector float fin_res[4] = {0};
2080
3087
  vector float vs[4] = {0};
2081
3088
  vector float CA[4] = {0};
2082
- __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
2083
- __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
3089
+ __builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value
3090
+ __builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value
2084
3091
  for (int l = 0; l < k; l++) {
2085
- __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
2086
- __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
2087
- __builtin_mma_xxsetaccz(&acc_0);
3092
+ __builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
3093
+ __builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
3094
+ __builtin_mma_xxsetaccz(& acc_0);
2088
3095
  if (isAblock_q4) {
2089
- packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
3096
+ packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);
2090
3097
  } else {
2091
- packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
3098
+ packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);
2092
3099
  }
2093
- packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
2094
- for(int x = 0; x < 8; x+=4) {
2095
- __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
2096
- __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
2097
- __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
2098
- __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
3100
+ packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);
3101
+ for (int x = 0; x < 8; x += 4) {
3102
+ __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
3103
+ __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);
3104
+ __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);
3105
+ __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);
2099
3106
  }
2100
- for (int I = 0; I<RM; I++) {
2101
- for (int J = 0; J<RN; J++) {
2102
- *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
3107
+ for (int I = 0; I < RM; I++) {
3108
+ for (int J = 0; J < RN; J++) {
3109
+ *((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
2103
3110
  }
2104
3111
  }
2105
- __builtin_mma_disassemble_acc(vec_C, &acc_0);
3112
+ __builtin_mma_disassemble_acc(vec_C, & acc_0);
2106
3113
  if (!isAblock_q4) {
2107
- auto aoffset = A+(ii*lda)+l;
3114
+ auto aoffset = A + (ii * lda) + l;
2108
3115
  for (int i = 0; i < RM; i++) {
2109
3116
  comparray[i] = 0;
2110
3117
  int ca = 0;
@@ -2127,15 +3134,15 @@ class tinyBLAS_Q0_PPC {
2127
3134
 
2128
3135
  template<int RM, int RN>
2129
3136
  inline void kernel(int64_t ii, int64_t jj) {
2130
- if constexpr(RM == 4 && RN == 8) {
2131
- KERNEL_4x8(ii,jj);
2132
- } else if constexpr(RM == 8 && RN == 4) {
2133
- KERNEL_8x4(ii,jj);
2134
- } else if constexpr(RM == 8 && RN == 8) {
2135
- KERNEL_8x8(ii,jj);
2136
- } else {
2137
- assert(false && "RN/RM values not supported");
2138
- }
3137
+ if constexpr(RM == 4 && RN == 8) {
3138
+ KERNEL_4x8(ii,jj);
3139
+ } else if constexpr(RM == 8 && RN == 4) {
3140
+ KERNEL_8x4(ii,jj);
3141
+ } else if constexpr(RM == 8 && RN == 8) {
3142
+ KERNEL_8x8(ii,jj);
3143
+ } else {
3144
+ assert(false && "RN/RM values not supported");
3145
+ }
2139
3146
  }
2140
3147
 
2141
3148
  template <int RM, int RN>
@@ -2154,11 +3161,11 @@ class tinyBLAS_Q0_PPC {
2154
3161
  kernel<RM, RN>(ii, jj);
2155
3162
  }
2156
3163
  }
2157
-
2158
- const TA *const A;
2159
- const block_q8_0 *const B;
2160
- float *C;
3164
+ const TA * const A;
3165
+ const block_q8_0 * const B;
3166
+ float * C;
2161
3167
  const int64_t k;
3168
+ int64_t kc;
2162
3169
  const int64_t lda;
2163
3170
  const int64_t ldb;
2164
3171
  const int64_t ldc;
@@ -2731,6 +3738,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
2731
3738
  params->ith, params->nth};
2732
3739
  tb.matmul(m, n);
2733
3740
  return true;
3741
+ #elif defined(__riscv_zvfh)
3742
+ #if LMUL == 1
3743
+ tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
3744
+ k, (const float *)A, lda,
3745
+ (const float *)B, ldb,
3746
+ (float *)C, ldc};
3747
+ #elif LMUL == 2
3748
+ tinyBLAS_RVV<vfloat32m2_t, vfloat32m2_t, float, float, float> tb{ params,
3749
+ k, (const float *)A, lda,
3750
+ (const float *)B, ldb,
3751
+ (float *)C, ldc};
3752
+ #else // LMUL = 4
3753
+ tinyBLAS_RVV<vfloat32m4_t, vfloat32m4_t, float, float, float> tb{ params,
3754
+ k, (const float *)A, lda,
3755
+ (const float *)B, ldb,
3756
+ (float *)C, ldc};
3757
+ #endif
3758
+ return tb.matmul(m, n);
2734
3759
  #else
2735
3760
  return false;
2736
3761
  #endif
@@ -2762,17 +3787,38 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
2762
3787
  return tb.matmul(m, n);
2763
3788
  }
2764
3789
  #elif defined(__MMA__)
2765
- if ((k % 8))
2766
- return false;
2767
- if(Btype == GGML_TYPE_BF16) {
2768
- tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
2769
- (const ggml_bf16_t *)A, lda,
2770
- (const ggml_bf16_t *)B, ldb,
2771
- (float *)C, ldc,
2772
- params->ith, params->nth};
2773
- tb.matmul(m, n);
2774
- return true;
3790
+ if (k % 8) {
3791
+ return false;
3792
+ }
3793
+
3794
+ if (Btype == GGML_TYPE_BF16) {
3795
+ tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
3796
+ (const ggml_bf16_t *)A, lda,
3797
+ (const ggml_bf16_t *)B, ldb,
3798
+ (float *)C, ldc,
3799
+ params->ith, params->nth };
3800
+
3801
+ tb.matmul(m, n);
3802
+ return true;
2775
3803
  }
3804
+ #elif defined(__riscv_zvfbfwma)
3805
+ #if LMUL == 1
3806
+ tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
3807
+ k, (const ggml_bf16_t *)A, lda,
3808
+ (const ggml_bf16_t *)B, ldb,
3809
+ (float *)C, ldc};
3810
+ #elif LMUL == 2
3811
+ tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
3812
+ k, (const ggml_bf16_t *)A, lda,
3813
+ (const ggml_bf16_t *)B, ldb,
3814
+ (float *)C, ldc};
3815
+ #else // LMUL = 4
3816
+ tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
3817
+ k, (const ggml_bf16_t *)A, lda,
3818
+ (const ggml_bf16_t *)B, ldb,
3819
+ (float *)C, ldc};
3820
+ #endif
3821
+ return tb.matmul(m, n);
2776
3822
  #endif
2777
3823
  return false;
2778
3824
  }
@@ -2822,6 +3868,41 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
2822
3868
  (float *)C, ldc};
2823
3869
  return tb.matmul(m, n);
2824
3870
  }
3871
+ #elif defined(__riscv_zvfh)
3872
+ if (Btype == GGML_TYPE_F16) {
3873
+ #if LMUL == 1
3874
+ tinyBLAS_RVV<vfloat32m1_t, vfloat16mf2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
3875
+ k, (const ggml_fp16_t *)A, lda,
3876
+ (const ggml_fp16_t *)B, ldb,
3877
+ (float *)C, ldc};
3878
+ #elif LMUL == 2
3879
+ tinyBLAS_RVV<vfloat32m2_t, vfloat16m1_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
3880
+ k, (const ggml_fp16_t *)A, lda,
3881
+ (const ggml_fp16_t *)B, ldb,
3882
+ (float *)C, ldc};
3883
+ #else // LMUL = 4
3884
+ tinyBLAS_RVV<vfloat32m4_t, vfloat16m2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
3885
+ k, (const ggml_fp16_t *)A, lda,
3886
+ (const ggml_fp16_t *)B, ldb,
3887
+ (float *)C, ldc};
3888
+ #endif
3889
+ return tb.matmul(m, n);
3890
+ }
3891
+ #elif defined(__MMA__)
3892
+ if (k % 8) {
3893
+ return false;
3894
+ }
3895
+
3896
+ if (Btype == GGML_TYPE_F16) {
3897
+ tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> tb{ k,
3898
+ (const ggml_fp16_t *)A, lda,
3899
+ (const ggml_fp16_t *)B, ldb,
3900
+ (float *)C, ldc,
3901
+ params->ith, params->nth };
3902
+
3903
+ tb.matmul(m, n);
3904
+ return true;
3905
+ }
2825
3906
  #endif
2826
3907
  return false;
2827
3908
  }