whispercpp 1.3.3 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -69,6 +69,10 @@
69
69
  #define VECTOR_REGISTERS 16
70
70
  #endif
71
71
 
72
+ #if defined(__riscv_v_intrinsic)
73
+ #define LMUL 4
74
+ #endif
75
+
72
76
  #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
73
77
 
74
78
  namespace {
@@ -117,8 +121,7 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
117
121
  #endif
118
122
 
119
123
  #if defined(__MMA__)
120
- typedef vector unsigned char vec_t;
121
- typedef __vector_quad acc_t;
124
+ #include "sgemm-ppc.h"
122
125
  #endif
123
126
  ////////////////////////////////////////////////////////////////////////////////////////////////////
124
127
  // VECTORIZED FUSED MULTIPLY ADD
@@ -176,6 +179,46 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
176
179
  }
177
180
  #endif
178
181
 
182
+ #if defined(__riscv_zvfh)
183
+ template <>
184
+ inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
185
+ return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
186
+ }
187
+ inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
188
+ return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
189
+ }
190
+ inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
191
+ return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
192
+ }
193
+ inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
194
+ return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
195
+ }
196
+ inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
197
+ return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
198
+ }
199
+ inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
200
+ return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
201
+ }
202
+ inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
203
+ return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
204
+ }
205
+ inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
206
+ return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
207
+ }
208
+ #endif
209
+
210
+ #if defined(__riscv_zvfbfwma)
211
+ inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
212
+ return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
213
+ }
214
+ inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
215
+ return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
216
+ }
217
+ inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
218
+ return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
219
+ }
220
+ #endif
221
+
179
222
  ////////////////////////////////////////////////////////////////////////////////////////////////////
180
223
  // VECTORIZED HORIZONTAL SUM
181
224
 
@@ -228,6 +271,25 @@ inline float hsum(__m512 x) {
228
271
  }
229
272
  #endif // __AVX512F__
230
273
 
274
+ #if defined(__riscv_zvfh)
275
+ inline float hsum(vfloat32m1_t x) {
276
+ return __riscv_vfmv_f_s_f32m1_f32(
277
+ __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));
278
+ }
279
+ inline float hsum(vfloat32m2_t x) {
280
+ return __riscv_vfmv_f_s_f32m1_f32(
281
+ __riscv_vfredusum_vs_f32m2_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m2()));
282
+ }
283
+ inline float hsum(vfloat32m4_t x) {
284
+ return __riscv_vfmv_f_s_f32m1_f32(
285
+ __riscv_vfredusum_vs_f32m4_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m4()));
286
+ }
287
+ inline float hsum(vfloat32m8_t x) {
288
+ return __riscv_vfmv_f_s_f32m1_f32(
289
+ __riscv_vfredusum_vs_f32m8_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m8()));
290
+ }
291
+ #endif
292
+
231
293
  ////////////////////////////////////////////////////////////////////////////////////////////////////
232
294
  // VECTORIZED MEMORY LOADING
233
295
 
@@ -316,6 +378,88 @@ template <> inline __m256bh load(const float *p) {
316
378
  }
317
379
  #endif
318
380
 
381
+ #if defined(__riscv_zvfh)
382
+ template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
383
+ return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
384
+ }
385
+ template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
386
+ return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
387
+ }
388
+ template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
389
+ return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
390
+ }
391
+ template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
392
+ return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
393
+ }
394
+ template <> inline vfloat32m1_t load(const float *p) {
395
+ return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
396
+ }
397
+ template <> inline vfloat32m2_t load(const float *p) {
398
+ return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2());
399
+ }
400
+ template <> inline vfloat32m4_t load(const float *p) {
401
+ return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4());
402
+ }
403
+ template <> inline vfloat32m8_t load(const float *p) {
404
+ return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8());
405
+ }
406
+ #endif
407
+
408
+ #if defined(__riscv_zvfbfwma)
409
+ template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
410
+ return __riscv_vle16_v_bf16mf2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16mf2());
411
+ }
412
+ template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
413
+ return __riscv_vle16_v_bf16m1(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m1());
414
+ }
415
+ template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
416
+ return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());
417
+ }
418
+ #endif
419
+
420
+ #if defined(__riscv_zvfh)
421
+ template <typename T> T set_zero();
422
+
423
+ template <> inline vfloat16mf2_t set_zero() {
424
+ return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());
425
+ }
426
+ template <> inline vfloat16m1_t set_zero() {
427
+ return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());
428
+ }
429
+ template <> inline vfloat16m2_t set_zero() {
430
+ return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());
431
+ }
432
+ template <> inline vfloat16m4_t set_zero() {
433
+ return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
434
+ }
435
+ template <> inline vfloat32m1_t set_zero() {
436
+ return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
437
+ }
438
+ template <> inline vfloat32m2_t set_zero() {
439
+ return __riscv_vfmv_v_f_f32m2(0, __riscv_vsetvlmax_e32m2());
440
+ }
441
+ template <> inline vfloat32m4_t set_zero() {
442
+ return __riscv_vfmv_v_f_f32m4(0, __riscv_vsetvlmax_e32m4());
443
+ }
444
+ template <> inline vfloat32m8_t set_zero() {
445
+ return __riscv_vfmv_v_f_f32m8(0, __riscv_vsetvlmax_e32m8());
446
+ }
447
+ #endif
448
+
449
+ #if defined(__riscv_v_intrinsic)
450
+ template <typename T> size_t vlmax() {
451
+ if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
452
+ else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
453
+ else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
454
+ else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
455
+ else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
456
+ else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
457
+ else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
458
+ else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
459
+ return 0;
460
+ }
461
+ #endif
462
+
319
463
  ////////////////////////////////////////////////////////////////////////////////////////////////////
320
464
  // FLOATING POINT MATRIX MULTIPLICATION
321
465
 
@@ -489,6 +633,573 @@ class tinyBLAS {
489
633
  const int64_t ldc;
490
634
  };
491
635
 
636
+ #if defined(__riscv_v_intrinsic)
637
+ template <typename D, typename V, typename TA, typename TB, typename TC>
638
+ class tinyBLAS_RVV {
639
+ public:
640
+ tinyBLAS_RVV(const ggml_compute_params * params, int64_t k,
641
+ const TA *A, int64_t lda,
642
+ const TB *B, int64_t ldb,
643
+ TC *C, int64_t ldc)
644
+ : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
645
+ }
646
+
647
+ bool matmul(int64_t m, int64_t n) {
648
+ if (k % vlmax<V>() != 0) {
649
+ return false;
650
+ }
651
+
652
+ #if LMUL == 1
653
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
654
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
655
+ mnpack<4, 6, 4>(m, n, SIZE_N, 12);
656
+ return true;
657
+ }
658
+ if (m % 8 == 0 ) {
659
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
660
+ mnpack<4, 6, 2>(m, n, SIZE_N, 12);
661
+ return true;
662
+ }
663
+ if (m % 4 == 0) {
664
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
665
+ mnpack<4, 6, 1>(m, n, SIZE_N, 12);
666
+ return true;
667
+ }
668
+ #elif LMUL == 2
669
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
670
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
671
+ mnpack<4, 3, 4>(m, n, SIZE_N, 24);
672
+ return true;
673
+ }
674
+ if (m % 8 == 0 ) {
675
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
676
+ mnpack<4, 3, 2>(m, n, SIZE_N, 24);
677
+ return true;
678
+ }
679
+ if (m % 4 == 0) {
680
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
681
+ mnpack<4, 3, 1>(m, n, SIZE_N, 24);
682
+ return true;
683
+ }
684
+ #else // LMUL = 4
685
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
686
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
687
+ mnpack<2, 2, 8>(m, n, SIZE_N, 36);
688
+ return true;
689
+ }
690
+ if (m % 8 == 0 ) {
691
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
692
+ mnpack<2, 2, 4>(m, n, SIZE_N, 36);
693
+ return true;
694
+ }
695
+ if (m % 4 == 0) {
696
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
697
+ mnpack<2, 2, 2>(m, n, SIZE_N, 36);
698
+ return true;
699
+ }
700
+ #endif
701
+ return false;
702
+ }
703
+
704
+ private:
705
+ template<int RM, int RN, int BM>
706
+ inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
707
+ if (SIZE_N == RN) {
708
+ return gemm<RM, RN, BM>(m, n, BN);
709
+ }
710
+ if constexpr (RN > 1) {
711
+ return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
712
+ } else {
713
+ GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
714
+ GGML_ASSERT(false); // we have miss something.
715
+ }
716
+ }
717
+
718
+ inline void gemm_bloc_4x6(int64_t ii, int64_t jj) {
719
+ size_t vl = vlmax<V>();
720
+ D Cv00 = set_zero<D>();
721
+ D Cv01 = set_zero<D>();
722
+ D Cv02 = set_zero<D>();
723
+ D Cv03 = set_zero<D>();
724
+ D Cv10 = set_zero<D>();
725
+ D Cv11 = set_zero<D>();
726
+ D Cv12 = set_zero<D>();
727
+ D Cv13 = set_zero<D>();
728
+ D Cv20 = set_zero<D>();
729
+ D Cv21 = set_zero<D>();
730
+ D Cv22 = set_zero<D>();
731
+ D Cv23 = set_zero<D>();
732
+ D Cv30 = set_zero<D>();
733
+ D Cv31 = set_zero<D>();
734
+ D Cv32 = set_zero<D>();
735
+ D Cv33 = set_zero<D>();
736
+ D Cv40 = set_zero<D>();
737
+ D Cv41 = set_zero<D>();
738
+ D Cv42 = set_zero<D>();
739
+ D Cv43 = set_zero<D>();
740
+ D Cv50 = set_zero<D>();
741
+ D Cv51 = set_zero<D>();
742
+ D Cv52 = set_zero<D>();
743
+ D Cv53 = set_zero<D>();
744
+
745
+ for (int64_t l = 0; l < k; l += vl) {
746
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
747
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
748
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
749
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
750
+ V Bv4 = load<V>(B + ldb * (jj + 4) + l);
751
+ V Bv5 = load<V>(B + ldb * (jj + 5) + l);
752
+
753
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
754
+ Cv00 = madd(Av0, Bv0, Cv00);
755
+ Cv10 = madd(Av0, Bv1, Cv10);
756
+ Cv20 = madd(Av0, Bv2, Cv20);
757
+ Cv30 = madd(Av0, Bv3, Cv30);
758
+ Cv40 = madd(Av0, Bv4, Cv40);
759
+ Cv50 = madd(Av0, Bv5, Cv50);
760
+
761
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
762
+ Cv01 = madd(Av1, Bv0, Cv01);
763
+ Cv11 = madd(Av1, Bv1, Cv11);
764
+ Cv21 = madd(Av1, Bv2, Cv21);
765
+ Cv31 = madd(Av1, Bv3, Cv31);
766
+ Cv41 = madd(Av1, Bv4, Cv41);
767
+ Cv51 = madd(Av1, Bv5, Cv51);
768
+
769
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
770
+ Cv02 = madd(Av2, Bv0, Cv02);
771
+ Cv12 = madd(Av2, Bv1, Cv12);
772
+ Cv22 = madd(Av2, Bv2, Cv22);
773
+ Cv32 = madd(Av2, Bv3, Cv32);
774
+ Cv42 = madd(Av2, Bv4, Cv42);
775
+ Cv52 = madd(Av2, Bv5, Cv52);
776
+
777
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
778
+ Cv03 = madd(Av3, Bv0, Cv03);
779
+ Cv13 = madd(Av3, Bv1, Cv13);
780
+ Cv23 = madd(Av3, Bv2, Cv23);
781
+ Cv33 = madd(Av3, Bv3, Cv33);
782
+ Cv43 = madd(Av3, Bv4, Cv43);
783
+ Cv53 = madd(Av3, Bv5, Cv53);
784
+ }
785
+
786
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
787
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
788
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
789
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
790
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
791
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
792
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
793
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
794
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
795
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
796
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
797
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
798
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
799
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
800
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
801
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
802
+ C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
803
+ C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
804
+ C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
805
+ C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
806
+ C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50);
807
+ C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51);
808
+ C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52);
809
+ C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53);
810
+ }
811
+
812
+ inline void gemm_bloc_4x5(int64_t ii, int64_t jj) {
813
+ size_t vl = vlmax<V>();
814
+ D Cv00 = set_zero<D>();
815
+ D Cv01 = set_zero<D>();
816
+ D Cv02 = set_zero<D>();
817
+ D Cv03 = set_zero<D>();
818
+ D Cv10 = set_zero<D>();
819
+ D Cv11 = set_zero<D>();
820
+ D Cv12 = set_zero<D>();
821
+ D Cv13 = set_zero<D>();
822
+ D Cv20 = set_zero<D>();
823
+ D Cv21 = set_zero<D>();
824
+ D Cv22 = set_zero<D>();
825
+ D Cv23 = set_zero<D>();
826
+ D Cv30 = set_zero<D>();
827
+ D Cv31 = set_zero<D>();
828
+ D Cv32 = set_zero<D>();
829
+ D Cv33 = set_zero<D>();
830
+ D Cv40 = set_zero<D>();
831
+ D Cv41 = set_zero<D>();
832
+ D Cv42 = set_zero<D>();
833
+ D Cv43 = set_zero<D>();
834
+
835
+ for (int64_t l = 0; l < k; l += vl) {
836
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
837
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
838
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
839
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
840
+ V Bv4 = load<V>(B + ldb * (jj + 4) + l);
841
+
842
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
843
+ Cv00 = madd(Av0, Bv0, Cv00);
844
+ Cv10 = madd(Av0, Bv1, Cv10);
845
+ Cv20 = madd(Av0, Bv2, Cv20);
846
+ Cv30 = madd(Av0, Bv3, Cv30);
847
+ Cv40 = madd(Av0, Bv4, Cv40);
848
+
849
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
850
+ Cv01 = madd(Av1, Bv0, Cv01);
851
+ Cv11 = madd(Av1, Bv1, Cv11);
852
+ Cv21 = madd(Av1, Bv2, Cv21);
853
+ Cv31 = madd(Av1, Bv3, Cv31);
854
+ Cv41 = madd(Av1, Bv4, Cv41);
855
+
856
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
857
+ Cv02 = madd(Av2, Bv0, Cv02);
858
+ Cv12 = madd(Av2, Bv1, Cv12);
859
+ Cv22 = madd(Av2, Bv2, Cv22);
860
+ Cv32 = madd(Av2, Bv3, Cv32);
861
+ Cv42 = madd(Av2, Bv4, Cv42);
862
+
863
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
864
+ Cv03 = madd(Av3, Bv0, Cv03);
865
+ Cv13 = madd(Av3, Bv1, Cv13);
866
+ Cv23 = madd(Av3, Bv2, Cv23);
867
+ Cv33 = madd(Av3, Bv3, Cv33);
868
+ Cv43 = madd(Av3, Bv4, Cv43);
869
+ }
870
+
871
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
872
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
873
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
874
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
875
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
876
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
877
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
878
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
879
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
880
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
881
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
882
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
883
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
884
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
885
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
886
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
887
+ C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
888
+ C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
889
+ C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
890
+ C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
891
+ }
892
+
893
+ inline void gemm_bloc_4x4(int64_t ii, int64_t jj) {
894
+ size_t vl = vlmax<V>();
895
+ D Cv00 = set_zero<D>();
896
+ D Cv01 = set_zero<D>();
897
+ D Cv02 = set_zero<D>();
898
+ D Cv03 = set_zero<D>();
899
+ D Cv10 = set_zero<D>();
900
+ D Cv11 = set_zero<D>();
901
+ D Cv12 = set_zero<D>();
902
+ D Cv13 = set_zero<D>();
903
+ D Cv20 = set_zero<D>();
904
+ D Cv21 = set_zero<D>();
905
+ D Cv22 = set_zero<D>();
906
+ D Cv23 = set_zero<D>();
907
+ D Cv30 = set_zero<D>();
908
+ D Cv31 = set_zero<D>();
909
+ D Cv32 = set_zero<D>();
910
+ D Cv33 = set_zero<D>();
911
+
912
+ for (int64_t l = 0; l < k; l += vl) {
913
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
914
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
915
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
916
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
917
+
918
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
919
+ Cv00 = madd(Av0, Bv0, Cv00);
920
+ Cv01 = madd(Av1, Bv0, Cv01);
921
+ Cv02 = madd(Av2, Bv0, Cv02);
922
+ Cv03 = madd(Av3, Bv0, Cv03);
923
+
924
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
925
+ Cv10 = madd(Av0, Bv1, Cv10);
926
+ Cv11 = madd(Av1, Bv1, Cv11);
927
+ Cv12 = madd(Av2, Bv1, Cv12);
928
+ Cv13 = madd(Av3, Bv1, Cv13);
929
+
930
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
931
+ Cv20 = madd(Av0, Bv2, Cv20);
932
+ Cv21 = madd(Av1, Bv2, Cv21);
933
+ Cv22 = madd(Av2, Bv2, Cv22);
934
+ Cv23 = madd(Av3, Bv2, Cv23);
935
+
936
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
937
+ Cv30 = madd(Av0, Bv3, Cv30);
938
+ Cv31 = madd(Av1, Bv3, Cv31);
939
+ Cv32 = madd(Av2, Bv3, Cv32);
940
+ Cv33 = madd(Av3, Bv3, Cv33);
941
+ }
942
+
943
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
944
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
945
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
946
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
947
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
948
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
949
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
950
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
951
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
952
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
953
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
954
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
955
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
956
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
957
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
958
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
959
+ }
960
+
961
+ inline void gemm_bloc_4x3(int64_t ii, int64_t jj) {
962
+ size_t vl = vlmax<V>();
963
+ D Cv00 = set_zero<D>();
964
+ D Cv01 = set_zero<D>();
965
+ D Cv02 = set_zero<D>();
966
+ D Cv03 = set_zero<D>();
967
+ D Cv10 = set_zero<D>();
968
+ D Cv11 = set_zero<D>();
969
+ D Cv12 = set_zero<D>();
970
+ D Cv13 = set_zero<D>();
971
+ D Cv20 = set_zero<D>();
972
+ D Cv21 = set_zero<D>();
973
+ D Cv22 = set_zero<D>();
974
+ D Cv23 = set_zero<D>();
975
+
976
+ for (int64_t l = 0; l < k; l += vl) {
977
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
978
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
979
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
980
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
981
+
982
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
983
+ Cv00 = madd(Av0, Bv0, Cv00);
984
+ Cv01 = madd(Av1, Bv0, Cv01);
985
+ Cv02 = madd(Av2, Bv0, Cv02);
986
+ Cv03 = madd(Av3, Bv0, Cv03);
987
+
988
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
989
+ Cv10 = madd(Av0, Bv1, Cv10);
990
+ Cv11 = madd(Av1, Bv1, Cv11);
991
+ Cv12 = madd(Av2, Bv1, Cv12);
992
+ Cv13 = madd(Av3, Bv1, Cv13);
993
+
994
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
995
+ Cv20 = madd(Av0, Bv2, Cv20);
996
+ Cv21 = madd(Av1, Bv2, Cv21);
997
+ Cv22 = madd(Av2, Bv2, Cv22);
998
+ Cv23 = madd(Av3, Bv2, Cv23);
999
+ }
1000
+
1001
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1002
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1003
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
1004
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
1005
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
1006
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
1007
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
1008
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
1009
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
1010
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
1011
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
1012
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
1013
+ }
1014
+
1015
+ inline void gemm_bloc_4x2(int64_t ii, int64_t jj) {
1016
+ size_t vl = vlmax<V>();
1017
+ D Cv00 = set_zero<D>();
1018
+ D Cv01 = set_zero<D>();
1019
+ D Cv02 = set_zero<D>();
1020
+ D Cv03 = set_zero<D>();
1021
+ D Cv10 = set_zero<D>();
1022
+ D Cv11 = set_zero<D>();
1023
+ D Cv12 = set_zero<D>();
1024
+ D Cv13 = set_zero<D>();
1025
+
1026
+ for (int64_t l = 0; l < k; l += vl) {
1027
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
1028
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
1029
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
1030
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
1031
+
1032
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1033
+ Cv00 = madd(Av0, Bv0, Cv00);
1034
+ Cv01 = madd(Av1, Bv0, Cv01);
1035
+ Cv02 = madd(Av2, Bv0, Cv02);
1036
+ Cv03 = madd(Av3, Bv0, Cv03);
1037
+
1038
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
1039
+ Cv10 = madd(Av0, Bv1, Cv10);
1040
+ Cv11 = madd(Av1, Bv1, Cv11);
1041
+ Cv12 = madd(Av2, Bv1, Cv12);
1042
+ Cv13 = madd(Av3, Bv1, Cv13);
1043
+ }
1044
+
1045
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1046
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1047
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
1048
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
1049
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
1050
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
1051
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
1052
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
1053
+ }
1054
+
1055
+ inline void gemm_bloc_4x1(int64_t ii, int64_t jj) {
1056
+ size_t vl = vlmax<V>();
1057
+ D Cv00 = set_zero<D>();
1058
+ D Cv01 = set_zero<D>();
1059
+ D Cv02 = set_zero<D>();
1060
+ D Cv03 = set_zero<D>();
1061
+
1062
+ for (int64_t l = 0; l < k; l += vl) {
1063
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
1064
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
1065
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
1066
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
1067
+
1068
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1069
+ Cv00 = madd(Av0, Bv0, Cv00);
1070
+ Cv01 = madd(Av1, Bv0, Cv01);
1071
+ Cv02 = madd(Av2, Bv0, Cv02);
1072
+ Cv03 = madd(Av3, Bv0, Cv03);
1073
+ }
1074
+
1075
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1076
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1077
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
1078
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
1079
+ }
1080
+
1081
+ inline void gemm_bloc_2x2(int64_t ii, int64_t jj) {
1082
+ size_t vl = vlmax<V>();
1083
+ D Cv00 = set_zero<D>();
1084
+ D Cv01 = set_zero<D>();
1085
+ D Cv10 = set_zero<D>();
1086
+ D Cv11 = set_zero<D>();
1087
+
1088
+ for (int64_t l = 0; l < k; l += vl) {
1089
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
1090
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
1091
+
1092
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1093
+ Cv00 = madd(Av0, Bv0, Cv00);
1094
+ Cv01 = madd(Av1, Bv0, Cv01);
1095
+
1096
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
1097
+ Cv10 = madd(Av0, Bv1, Cv10);
1098
+ Cv11 = madd(Av1, Bv1, Cv11);
1099
+ }
1100
+
1101
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1102
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1103
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
1104
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
1105
+ }
1106
+
1107
+ inline void gemm_bloc_2x1(int64_t ii, int64_t jj) {
1108
+ size_t vl = vlmax<V>();
1109
+ D Cv00 = set_zero<D>();
1110
+ D Cv01 = set_zero<D>();
1111
+
1112
+ for (int64_t l = 0; l < k; l += vl) {
1113
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
1114
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
1115
+
1116
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1117
+ Cv00 = madd(Av0, Bv0, Cv00);
1118
+ Cv01 = madd(Av1, Bv0, Cv01);
1119
+ }
1120
+
1121
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1122
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1123
+ }
1124
+
1125
+ template <int RM, int RN>
1126
+ inline void gemm_bloc(int64_t ii, int64_t jj) {
1127
+ if constexpr (RM == 4) {
1128
+ if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); }
1129
+ if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); }
1130
+ if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); }
1131
+ if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); }
1132
+ if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); }
1133
+ if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); }
1134
+ } else if constexpr (RM == 2) {
1135
+ if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); }
1136
+ if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); }
1137
+ }
1138
+ }
1139
+
1140
+ template <int RM, int RN, int BM>
1141
+ NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
1142
+ GGML_ASSERT(m % (RM * BM) == 0);
1143
+ const int64_t ytiles = m / (RM * BM);
1144
+ const int64_t xtiles = (n + RN -1) / RN;
1145
+ const int64_t jj_RN = (xtiles - (xtiles * RN - n));
1146
+
1147
+ // "round" bloc_size to "nearest" BN
1148
+ const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
1149
+ const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
1150
+ const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
1151
+ const int64_t nb_job = ytiles * NB_BN;
1152
+
1153
+ if (params->ith == 0) {
1154
+ GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
1155
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
1156
+ ggml_threadpool_chunk_set(params->threadpool, params->nth);
1157
+ }
1158
+
1159
+ ggml_barrier(params->threadpool);
1160
+
1161
+ int64_t job = params->ith;
1162
+ while (job < nb_job) {
1163
+ const int64_t ii = (job % ytiles) * RM * BM;
1164
+ const int64_t jb = job / ytiles;
1165
+ const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
1166
+ const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
1167
+
1168
+ const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
1169
+ const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
1170
+ const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
1171
+
1172
+ for (int64_t bi = 0; bi < BM * RM; bi += RM) {
1173
+ int64_t jj = jj0;
1174
+ for (; jj < jj1; jj += RN) {
1175
+ gemm_bloc<RM, RN>(ii + bi, jj);
1176
+ }
1177
+ if constexpr (RN > 1) {
1178
+ for (; jj < jj2; jj += RN - 1) {
1179
+ gemm_bloc<RM, RN-1>(ii + bi, jj);
1180
+ }
1181
+ }
1182
+ GGML_ASSERT(jj == jj2);
1183
+ }
1184
+
1185
+ job = ggml_threadpool_chunk_add(params->threadpool, 1);
1186
+ }
1187
+
1188
+ ggml_barrier(params->threadpool);
1189
+ return;
1190
+ }
1191
+
1192
+ const ggml_compute_params * params;
1193
+ const TA *const A;
1194
+ const TB *const B;
1195
+ TC *const C;
1196
+ const int64_t k;
1197
+ const int64_t lda;
1198
+ const int64_t ldb;
1199
+ const int64_t ldc;
1200
+ };
1201
+ #endif
1202
+
492
1203
  //////////////////////////////////////////////////////////////////////////////////////////
493
1204
  // QUANT ZERO MATRIX MULTIPLICATION
494
1205
 
@@ -1541,7 +2252,7 @@ class tinyBLAS_BF16_PPC {
1541
2252
  } else if constexpr(RM == 8 && RN == 4) {
1542
2253
  KERNEL_8x4(ii,jj);
1543
2254
  } else {
1544
- static_assert(false, "RN/RM values not supported");
2255
+ assert(false && "RN/RM values not supported");
1545
2256
  }
1546
2257
  }
1547
2258
 
@@ -1573,67 +2284,44 @@ class tinyBLAS_BF16_PPC {
1573
2284
  const int nth;
1574
2285
  };
1575
2286
 
1576
- template <typename TA, typename TB, typename TC>
1577
- class tinyBLAS_Q0_PPC {
1578
- public:
1579
- tinyBLAS_Q0_PPC(int64_t k,
1580
- const TA *A, int64_t lda,
1581
- const TB *B, int64_t ldb,
1582
- TC *C, int64_t ldc,
1583
- int ith, int nth)
2287
+ template <typename TA>
2288
+ tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
2289
+ const TA *A, int64_t lda,
2290
+ const block_q8_0 *B, int64_t ldb,
2291
+ float *C, int64_t ldc,
2292
+ int ith, int nth)
1584
2293
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
2294
+ kc = 64;
1585
2295
  }
1586
2296
 
1587
- void matmul(int64_t m, int64_t n) {
1588
- mnpack(0, m, 0, n);
1589
- }
1590
-
1591
- private:
1592
-
1593
- template<int RM, int RN>
1594
- inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
1595
- for (int I = 0; I < RM; I++) {
1596
- for (int J = 0; J < RN; J++) {
1597
- *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
1598
- }
1599
- }
1600
- }
1601
-
1602
- template<int size>
1603
- inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
1604
- vector signed int vec_C[4];
1605
- vector float CA[4] = {0};
1606
- vector float res[4] = {0};
1607
- __builtin_mma_disassemble_acc(vec_C, ACC);
1608
- for (int i = 0; i < 4; i++) {
1609
- CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
1610
- res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
1611
- fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
1612
- }
2297
+ template<typename TA>
2298
+ void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
2299
+ int mc = 64; int nc = 64;
2300
+ if (n % 8 == 0 && n < nc) {
2301
+ nc = n;
2302
+ mc = 32 ;
2303
+ kc = 32;
2304
+ }
2305
+ const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
2306
+ if (is_aligned) {
2307
+ this->matmul_tiled_q0(m, n, mc, nc, kc);
2308
+ } else {
2309
+ mnpack(0, m, 0, n);
2310
+ }
1613
2311
  }
1614
2312
 
1615
- template<typename VA, typename VB, int size>
1616
- void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array<int, size>& comparray) {
2313
+ template<typename TA>
2314
+ template<int size>
2315
+ void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
1617
2316
  int64_t i, j;
1618
2317
  TA *aoffset = NULL;
1619
- VA *vecOffset = NULL;
2318
+ int8_t *vecOffset = NULL;
1620
2319
  TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1621
2320
  TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1622
- VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
1623
- VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
1624
- VB t1, t2, t3, t4, t5, t6, t7, t8;
1625
- const vector signed char lowMask = vec_splats((signed char)0xF);
1626
- const vector unsigned char v4 = vec_splats((unsigned char)0x4);
1627
- const vector signed char v8 = vec_splats((signed char)0x8);
2321
+ vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
2322
+ vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
1628
2323
  aoffset = const_cast<TA*>(a);
1629
2324
  vecOffset = vec;
1630
- vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
1631
- vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
1632
- vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
1633
- vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
1634
- vector signed int vsum = {0};
1635
- vector signed int vsum2 = {0};
1636
-
1637
2325
  j = (rows >> 3);
1638
2326
  if (j > 0) {
1639
2327
  do {
@@ -1646,159 +2334,30 @@ class tinyBLAS_Q0_PPC {
1646
2334
  aoffset7 = aoffset6 + lda;
1647
2335
  aoffset8 = aoffset7 + lda;
1648
2336
  aoffset += 8 * lda;
1649
-
1650
2337
  i = (cols >> 2);
1651
2338
  if (i > 0) {
1652
2339
  do {
1653
- c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
1654
- c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
1655
- c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
1656
- c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
1657
- c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
1658
- c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
1659
- c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
1660
- c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
1661
-
1662
- c1[0] = vec_and(c1[1], lowMask);
1663
- c1[1] = vec_sr(c1[1], v4);
1664
- c1[0] = vec_sub(c1[0], v8);
1665
- c1[1] = vec_sub(c1[1], v8);
1666
- vsum = vec_sum4s(c1[0], vsum);
1667
- vsum2 = vec_sum4s(c1[1], vsum2);
1668
- vsum = vec_add(vsum, vsum2);
1669
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1670
- vsum = vec_splats(0);
1671
- vsum2 = vec_splats(0);
1672
-
1673
- c2[0] = vec_and(c2[1], lowMask);
1674
- c2[1] = vec_sr(c2[1], v4);
1675
- c2[0] = vec_sub(c2[0], v8);
1676
- c2[1] = vec_sub(c2[1], v8);
1677
- vsum = vec_sum4s(c2[0], vsum);
1678
- vsum2 = vec_sum4s(c2[1], vsum2);
1679
- vsum = vec_add(vsum, vsum2);
1680
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1681
- vsum = vec_splats(0);
1682
- vsum2 = vec_splats(0);
1683
-
1684
- c3[0] = vec_and(c3[1], lowMask);
1685
- c3[1] = vec_sr(c3[1], v4);
1686
- c3[0] = vec_sub(c3[0], v8);
1687
- c3[1] = vec_sub(c3[1], v8);
1688
- vsum = vec_sum4s(c3[0], vsum);
1689
- vsum2 = vec_sum4s(c3[1], vsum2);
1690
- vsum = vec_add(vsum, vsum2);
1691
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1692
- vsum = vec_splats(0);
1693
- vsum2 = vec_splats(0);
1694
-
1695
- c4[0] = vec_and(c4[1], lowMask);
1696
- c4[1] = vec_sr(c4[1], v4);
1697
- c4[0] = vec_sub(c4[0], v8);
1698
- c4[1] = vec_sub(c4[1], v8);
1699
- vsum = vec_sum4s(c4[0], vsum);
1700
- vsum2 = vec_sum4s(c4[1], vsum2);
1701
- vsum = vec_add(vsum, vsum2);
1702
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1703
- vsum = vec_splats(0);
1704
- vsum2 = vec_splats(0);
1705
-
1706
- c5[0] = vec_and(c5[1], lowMask);
1707
- c5[1] = vec_sr(c5[1], v4);
1708
- c5[0] = vec_sub(c5[0], v8);
1709
- c5[1] = vec_sub(c5[1], v8);
1710
- vsum = vec_sum4s(c5[0], vsum);
1711
- vsum2 = vec_sum4s(c5[1], vsum2);
1712
- vsum = vec_add(vsum, vsum2);
1713
- comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1714
- vsum = vec_splats(0);
1715
- vsum2 = vec_splats(0);
1716
-
1717
- c6[0] = vec_and(c6[1], lowMask);
1718
- c6[1] = vec_sr(c6[1], v4);
1719
- c6[0] = vec_sub(c6[0], v8);
1720
- c6[1] = vec_sub(c6[1], v8);
1721
- vsum = vec_sum4s(c6[0], vsum);
1722
- vsum2 = vec_sum4s(c6[1], vsum2);
1723
- vsum = vec_add(vsum, vsum2);
1724
- comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1725
- vsum = vec_splats(0);
1726
- vsum2 = vec_splats(0);
1727
-
1728
- c7[0] = vec_and(c7[1], lowMask);
1729
- c7[1] = vec_sr(c7[1], v4);
1730
- c7[0] = vec_sub(c7[0], v8);
1731
- c7[1] = vec_sub(c7[1], v8);
1732
- vsum = vec_sum4s(c7[0], vsum);
1733
- vsum2 = vec_sum4s(c7[1], vsum2);
1734
- vsum = vec_add(vsum, vsum2);
1735
- comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1736
- vsum = vec_splats(0);
1737
- vsum2 = vec_splats(0);
1738
-
1739
- c8[0] = vec_and(c8[1], lowMask);
1740
- c8[1] = vec_sr(c8[1], v4);
1741
- c8[0] = vec_sub(c8[0], v8);
1742
- c8[1] = vec_sub(c8[1], v8);
1743
- vsum = vec_sum4s(c8[0], vsum);
1744
- vsum2 = vec_sum4s(c8[1], vsum2);
1745
- vsum = vec_add(vsum, vsum2);
1746
- comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1747
- vsum = vec_splats(0);
1748
- vsum2 = vec_splats(0);
1749
-
1750
- t1 = vec_perm(c1[0], c2[0], swiz1);
1751
- t2 = vec_perm(c1[0], c2[0], swiz2);
1752
- t3 = vec_perm(c3[0], c4[0], swiz1);
1753
- t4 = vec_perm(c3[0], c4[0], swiz2);
1754
- t5 = vec_perm(t1, t3, swiz3);
1755
- t6 = vec_perm(t1, t3, swiz4);
1756
- t7 = vec_perm(t2, t4, swiz3);
1757
- t8 = vec_perm(t2, t4, swiz4);
1758
- vec_xst(t5, 0, vecOffset);
1759
- vec_xst(t6, 0, vecOffset+16);
1760
- vec_xst(t7, 0, vecOffset+32);
1761
- vec_xst(t8, 0, vecOffset+48);
1762
-
1763
- t1 = vec_perm(c1[1], c2[1], swiz1);
1764
- t2 = vec_perm(c1[1], c2[1], swiz2);
1765
- t3 = vec_perm(c3[1], c4[1], swiz1);
1766
- t4 = vec_perm(c3[1], c4[1], swiz2);
1767
- t5 = vec_perm(t1, t3, swiz3);
1768
- t6 = vec_perm(t1, t3, swiz4);
1769
- t7 = vec_perm(t2, t4, swiz3);
1770
- t8 = vec_perm(t2, t4, swiz4);
1771
- vec_xst(t5, 0, vecOffset+64);
1772
- vec_xst(t6, 0, vecOffset+80);
1773
- vec_xst(t7, 0, vecOffset+96);
1774
- vec_xst(t8, 0, vecOffset+112);
1775
-
1776
- t1 = vec_perm(c5[0], c6[0], swiz1);
1777
- t2 = vec_perm(c5[0], c6[0], swiz2);
1778
- t3 = vec_perm(c7[0], c8[0], swiz1);
1779
- t4 = vec_perm(c7[0], c8[0], swiz2);
1780
- t5 = vec_perm(t1, t3, swiz3);
1781
- t6 = vec_perm(t1, t3, swiz4);
1782
- t7 = vec_perm(t2, t4, swiz3);
1783
- t8 = vec_perm(t2, t4, swiz4);
1784
- vec_xst(t5, 0, vecOffset+128);
1785
- vec_xst(t6, 0, vecOffset+144);
1786
- vec_xst(t7, 0, vecOffset+160);
1787
- vec_xst(t8, 0, vecOffset+176);
1788
-
1789
- t1 = vec_perm(c5[1], c6[1], swiz1);
1790
- t2 = vec_perm(c5[1], c6[1], swiz2);
1791
- t3 = vec_perm(c7[1], c8[1], swiz1);
1792
- t4 = vec_perm(c7[1], c8[1], swiz2);
1793
- t5 = vec_perm(t1, t3, swiz3);
1794
- t6 = vec_perm(t1, t3, swiz4);
1795
- t7 = vec_perm(t2, t4, swiz3);
1796
- t8 = vec_perm(t2, t4, swiz4);
1797
- vec_xst(t5, 0, vecOffset+192);
1798
- vec_xst(t6, 0, vecOffset+208);
1799
- vec_xst(t7, 0, vecOffset+224);
1800
- vec_xst(t8, 0, vecOffset+240);
1801
-
2340
+ c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
2341
+ c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
2342
+ c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
2343
+ c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
2344
+ c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
2345
+ c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
2346
+ c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
2347
+ c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
2348
+
2349
+ process_q4_elements(c1, &comparray[0]);
2350
+ process_q4_elements(c2, &comparray[1]);
2351
+ process_q4_elements(c3, &comparray[2]);
2352
+ process_q4_elements(c4, &comparray[3]);
2353
+ process_q4_elements(c5, &comparray[4]);
2354
+ process_q4_elements(c6, &comparray[5]);
2355
+ process_q4_elements(c7, &comparray[6]);
2356
+ process_q4_elements(c8, &comparray[7]);
2357
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
2358
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
2359
+ vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
2360
+ vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
1802
2361
  aoffset1 += lda;
1803
2362
  aoffset2 += lda;
1804
2363
  aoffset3 += lda;
@@ -1821,85 +2380,20 @@ class tinyBLAS_Q0_PPC {
1821
2380
  aoffset3 = aoffset2 + lda;
1822
2381
  aoffset4 = aoffset3 + lda;
1823
2382
  aoffset += 4 * lda;
1824
-
1825
2383
  i = (cols >> 2);
1826
2384
  if (i > 0) {
1827
2385
  do {
1828
- c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
1829
- c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
1830
- c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
1831
- c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
1832
-
1833
- c1[0] = vec_and(c1[1], lowMask);
1834
- c1[1] = vec_sr(c1[1], v4);
1835
- c1[0] = vec_sub(c1[0], v8);
1836
- c1[1] = vec_sub(c1[1], v8);
1837
- vsum = vec_sum4s(c1[0], vsum);
1838
- vsum2 = vec_sum4s(c1[1], vsum2);
1839
- vsum = vec_add(vsum, vsum2);
1840
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1841
- vsum = vec_splats(0);
1842
- vsum2 = vec_splats(0);
1843
-
1844
- c2[0] = vec_and(c2[1], lowMask);
1845
- c2[1] = vec_sr(c2[1], v4);
1846
- c2[0] = vec_sub(c2[0], v8);
1847
- c2[1] = vec_sub(c2[1], v8);
1848
- vsum = vec_sum4s(c2[0], vsum);
1849
- vsum2 = vec_sum4s(c2[1], vsum2);
1850
- vsum = vec_add(vsum, vsum2);
1851
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1852
- vsum = vec_splats(0);
1853
- vsum2 = vec_splats(0);
1854
-
1855
- c3[0] = vec_and(c3[1], lowMask);
1856
- c3[1] = vec_sr(c3[1], v4);
1857
- c3[0] = vec_sub(c3[0], v8);
1858
- c3[1] = vec_sub(c3[1], v8);
1859
- vsum = vec_sum4s(c3[0], vsum);
1860
- vsum2 = vec_sum4s(c3[1], vsum2);
1861
- vsum = vec_add(vsum, vsum2);
1862
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1863
- vsum = vec_splats(0);
1864
- vsum2 = vec_splats(0);
1865
-
1866
- c4[0] = vec_and(c4[1], lowMask);
1867
- c4[1] = vec_sr(c4[1], v4);
1868
- c4[0] = vec_sub(c4[0], v8);
1869
- c4[1] = vec_sub(c4[1], v8);
1870
- vsum = vec_sum4s(c4[0], vsum);
1871
- vsum2 = vec_sum4s(c4[1], vsum2);
1872
- vsum = vec_add(vsum, vsum2);
1873
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1874
- vsum = vec_splats(0);
1875
- vsum2 = vec_splats( 0);
1876
-
1877
- t1 = vec_perm(c1[0], c2[0], swiz1);
1878
- t2 = vec_perm(c1[0], c2[0], swiz2);
1879
- t3 = vec_perm(c3[0], c4[0], swiz1);
1880
- t4 = vec_perm(c3[0], c4[0], swiz2);
1881
- t5 = vec_perm(t1, t3, swiz3);
1882
- t6 = vec_perm(t1, t3, swiz4);
1883
- t7 = vec_perm(t2, t4, swiz3);
1884
- t8 = vec_perm(t2, t4, swiz4);
1885
- vec_xst(t5, 0, vecOffset);
1886
- vec_xst(t6, 0, vecOffset+16);
1887
- vec_xst(t7, 0, vecOffset+32);
1888
- vec_xst(t8, 0, vecOffset+48);
1889
-
1890
- t1 = vec_perm(c1[1], c2[1], swiz1);
1891
- t2 = vec_perm(c1[1], c2[1], swiz2);
1892
- t3 = vec_perm(c3[1], c4[1], swiz1);
1893
- t4 = vec_perm(c3[1], c4[1], swiz2);
1894
- t5 = vec_perm(t1, t3, swiz3);
1895
- t6 = vec_perm(t1, t3, swiz4);
1896
- t7 = vec_perm(t2, t4, swiz3);
1897
- t8 = vec_perm(t2, t4, swiz4);
1898
- vec_xst(t5, 0, vecOffset+64);
1899
- vec_xst(t6, 0, vecOffset+80);
1900
- vec_xst(t7, 0, vecOffset+96);
1901
- vec_xst(t8, 0, vecOffset+112);
1902
-
2386
+ c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
2387
+ c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
2388
+ c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
2389
+ c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
2390
+
2391
+ process_q4_elements(c1, &comparray[0]);
2392
+ process_q4_elements(c2, &comparray[1]);
2393
+ process_q4_elements(c3, &comparray[2]);
2394
+ process_q4_elements(c4, &comparray[3]);
2395
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
2396
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
1903
2397
  aoffset1 += lda;
1904
2398
  aoffset2 += lda;
1905
2399
  aoffset3 += lda;
@@ -1918,80 +2412,17 @@ class tinyBLAS_Q0_PPC {
1918
2412
  if (i > 0) {
1919
2413
  do {
1920
2414
  switch(rows) {
1921
- case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
1922
- case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
1923
- case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
2415
+ case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
2416
+ case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
2417
+ case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
1924
2418
  break;
1925
2419
  }
1926
- c1[0] = vec_and(c1[1], lowMask);
1927
- c1[1] = vec_sr(c1[1], v4);
1928
- c1[0] = vec_sub(c1[0], v8);
1929
- c1[1] = vec_sub(c1[1], v8);
1930
- vsum = vec_sum4s(c1[0], vsum);
1931
- vsum2 = vec_sum4s(c1[1], vsum2);
1932
- vsum = vec_add(vsum, vsum2);
1933
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1934
- vsum = vec_splats(0);
1935
- vsum2 = vec_splats(0);
1936
-
1937
- c2[0] = vec_and(c2[1], lowMask);
1938
- c2[1] = vec_sr(c2[1], v4);
1939
- c2[0] = vec_sub(c2[0], v8);
1940
- c2[1] = vec_sub(c2[1], v8);
1941
- vsum = vec_sum4s(c2[0], vsum);
1942
- vsum2 = vec_sum4s(c2[1], vsum2);
1943
- vsum = vec_add(vsum, vsum2);
1944
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1945
- vsum = vec_splats(0);
1946
- vsum2 = vec_splats(0);
1947
-
1948
- c3[0] = vec_and(c3[1], lowMask);
1949
- c3[1] = vec_sr(c3[1], v4);
1950
- c3[0] = vec_sub(c3[0], v8);
1951
- c3[1] = vec_sub(c3[1], v8);
1952
- vsum = vec_sum4s(c3[0], vsum);
1953
- vsum2 = vec_sum4s(c3[1], vsum2);
1954
- vsum = vec_add(vsum, vsum2);
1955
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1956
- vsum = vec_splats(0);
1957
- vsum2 = vec_splats(0);
1958
-
1959
- c4[0] = vec_and(c4[1], lowMask);
1960
- c4[1] = vec_sr(c4[1], v4);
1961
- c4[0] = vec_sub(c4[0], v8);
1962
- c4[1] = vec_sub(c4[1], v8);
1963
- vsum = vec_sum4s(c4[0], vsum);
1964
- vsum2 = vec_sum4s(c4[1], vsum2);
1965
- vsum = vec_add(vsum, vsum2);
1966
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1967
- vsum = vec_splats(0);
1968
- vsum2 = vec_splats(0);
1969
-
1970
- t1 = vec_perm(c1[0], c2[0], swiz1);
1971
- t2 = vec_perm(c1[0], c2[0], swiz2);
1972
- t3 = vec_perm(c3[0], c4[0], swiz1);
1973
- t4 = vec_perm(c3[0], c4[0], swiz2);
1974
- t5 = vec_perm(t1, t3, swiz3);
1975
- t6 = vec_perm(t1, t3, swiz4);
1976
- t7 = vec_perm(t2, t4, swiz3);
1977
- t8 = vec_perm(t2, t4, swiz4);
1978
- vec_xst(t5, 0, vecOffset);
1979
- vec_xst(t6, 0, vecOffset+16);
1980
- vec_xst(t7, 0, vecOffset+32);
1981
- vec_xst(t8, 0, vecOffset+48);
1982
-
1983
- t1 = vec_perm(c1[1], c2[1], swiz1);
1984
- t2 = vec_perm(c1[1], c2[1], swiz2);
1985
- t3 = vec_perm(c3[1], c4[1], swiz1);
1986
- t4 = vec_perm(c3[1], c4[1], swiz2);
1987
- t5 = vec_perm(t1, t3, swiz3);
1988
- t6 = vec_perm(t1, t3, swiz4);
1989
- t7 = vec_perm(t2, t4, swiz3);
1990
- t8 = vec_perm(t2, t4, swiz4);
1991
- vec_xst(t5, 0, vecOffset+64);
1992
- vec_xst(t6, 0, vecOffset+80);
1993
- vec_xst(t7, 0, vecOffset+96);
1994
- vec_xst(t8, 0, vecOffset+112);
2420
+ process_q4_elements(c1, &comparray[0]);
2421
+ process_q4_elements(c2, &comparray[1]);
2422
+ process_q4_elements(c3, &comparray[2]);
2423
+ process_q4_elements(c4, &comparray[3]);
2424
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
2425
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
1995
2426
  aoffset1 += lda;
1996
2427
  aoffset2 += lda;
1997
2428
  aoffset3 += lda;
@@ -2002,145 +2433,41 @@ class tinyBLAS_Q0_PPC {
2002
2433
  }
2003
2434
  }
2004
2435
 
2436
+ template<typename TA>
2005
2437
  template<typename VA, typename VB>
2006
- void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
2438
+ void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
2007
2439
  int64_t i, j;
2008
- TB *aoffset = NULL;
2440
+ block_q8_0 *aoffset = NULL;
2009
2441
  VA *vecOffset = NULL;
2010
- TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
2011
- TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
2012
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
2013
- VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
2014
- VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
2015
- VB t1, t2, t3, t4, t5, t6, t7, t8;
2016
- vector unsigned char xor_vector;
2017
- uint8_t flip_vec = 0x80;
2018
- xor_vector = vec_splats(flip_vec);
2019
- vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
2020
- vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
2021
- vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
2022
- vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
2023
-
2024
- aoffset = const_cast<TB*>(a);
2442
+ block_q8_0* aoffsets[8];
2443
+ __vector_pair arr[8];
2444
+ VB c[8][2] = {0};
2445
+ VB c1[8] = {0}; VB c2[8] = {0};
2446
+ aoffset = const_cast<block_q8_0*>(a);
2025
2447
  vecOffset = vec;
2026
2448
  j = (rows >> 3);
2027
2449
  if (j > 0) {
2028
2450
  do {
2029
- aoffset1 = aoffset;
2030
- aoffset2 = aoffset1 + lda;
2031
- aoffset3 = aoffset2 + lda;
2032
- aoffset4 = aoffset3 + lda;
2033
- aoffset5 = aoffset4 + lda;
2034
- aoffset6 = aoffset5 + lda;
2035
- aoffset7 = aoffset6 + lda;
2036
- aoffset8 = aoffset7 + lda;
2451
+ aoffsets[0] = aoffset;
2452
+ for (int it = 1; it < 8; it++)
2453
+ aoffsets[it] = aoffsets[it-1] + lda;
2037
2454
  aoffset += 8 * lda;
2038
2455
 
2039
2456
  i = (cols >> 3);
2040
2457
  if (i > 0) {
2041
2458
  do {
2042
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
2043
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
2044
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
2045
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
2046
- C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
2047
- C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
2048
- C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
2049
- C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
2050
-
2051
- __builtin_vsx_disassemble_pair(c1, &C1);
2052
- __builtin_vsx_disassemble_pair(c2, &C2);
2053
- __builtin_vsx_disassemble_pair(c3, &C3);
2054
- __builtin_vsx_disassemble_pair(c4, &C4);
2055
- __builtin_vsx_disassemble_pair(c5, &C5);
2056
- __builtin_vsx_disassemble_pair(c6, &C6);
2057
- __builtin_vsx_disassemble_pair(c7, &C7);
2058
- __builtin_vsx_disassemble_pair(c8, &C8);
2059
-
2060
- t1 = vec_perm(c1[0], c2[0], swiz1);
2061
- t2 = vec_perm(c1[0], c2[0], swiz2);
2062
- t3 = vec_perm(c3[0], c4[0], swiz1);
2063
- t4 = vec_perm(c3[0], c4[0], swiz2);
2064
- t5 = vec_perm(t1, t3, swiz3);
2065
- t6 = vec_perm(t1, t3, swiz4);
2066
- t7 = vec_perm(t2, t4, swiz3);
2067
- t8 = vec_perm(t2, t4, swiz4);
2068
- if (flip == true) {
2069
- t5 = vec_xor(t5, xor_vector);
2070
- t6 = vec_xor(t6, xor_vector);
2071
- t7 = vec_xor(t7, xor_vector);
2072
- t8 = vec_xor(t8, xor_vector);
2073
- }
2074
- vec_xst(t5, 0, vecOffset);
2075
- vec_xst(t6, 0, vecOffset+16);
2076
- vec_xst(t7, 0, vecOffset+32);
2077
- vec_xst(t8, 0, vecOffset+48);
2078
-
2079
- t1 = vec_perm(c1[1], c2[1], swiz1);
2080
- t2 = vec_perm(c1[1], c2[1], swiz2);
2081
- t3 = vec_perm(c3[1], c4[1], swiz1);
2082
- t4 = vec_perm(c3[1], c4[1], swiz2);
2083
- t5 = vec_perm(t1, t3, swiz3);
2084
- t6 = vec_perm(t1, t3, swiz4);
2085
- t7 = vec_perm(t2, t4, swiz3);
2086
- t8 = vec_perm(t2, t4, swiz4);
2087
- if (flip == true) {
2088
- t5 = vec_xor(t5, xor_vector);
2089
- t6 = vec_xor(t6, xor_vector);
2090
- t7 = vec_xor(t7, xor_vector);
2091
- t8 = vec_xor(t8, xor_vector);
2092
- }
2093
- vec_xst(t5, 0, vecOffset+64);
2094
- vec_xst(t6, 0, vecOffset+80);
2095
- vec_xst(t7, 0, vecOffset+96);
2096
- vec_xst(t8, 0, vecOffset+112);
2097
-
2098
- t1 = vec_perm(c5[0], c6[0], swiz1);
2099
- t2 = vec_perm(c5[0], c6[0], swiz2);
2100
- t3 = vec_perm(c7[0], c8[0], swiz1);
2101
- t4 = vec_perm(c7[0], c8[0], swiz2);
2102
- t5 = vec_perm(t1, t3, swiz3);
2103
- t6 = vec_perm(t1, t3, swiz4);
2104
- t7 = vec_perm(t2, t4, swiz3);
2105
- t8 = vec_perm(t2, t4, swiz4);
2106
- if (flip == true) {
2107
- t5 = vec_xor(t5, xor_vector);
2108
- t6 = vec_xor(t6, xor_vector);
2109
- t7 = vec_xor(t7, xor_vector);
2110
- t8 = vec_xor(t8, xor_vector);
2111
- }
2112
- vec_xst(t5, 0, vecOffset+128);
2113
- vec_xst(t6, 0, vecOffset+144);
2114
- vec_xst(t7, 0, vecOffset+160);
2115
- vec_xst(t8, 0, vecOffset+176);
2116
-
2117
- t1 = vec_perm(c5[1], c6[1], swiz1);
2118
- t2 = vec_perm(c5[1], c6[1], swiz2);
2119
- t3 = vec_perm(c7[1], c8[1], swiz1);
2120
- t4 = vec_perm(c7[1], c8[1], swiz2);
2121
- t5 = vec_perm(t1, t3, swiz3);
2122
- t6 = vec_perm(t1, t3, swiz4);
2123
- t7 = vec_perm(t2, t4, swiz3);
2124
- t8 = vec_perm(t2, t4, swiz4);
2125
- if (flip == true) {
2126
- t5 = vec_xor(t5, xor_vector);
2127
- t6 = vec_xor(t6, xor_vector);
2128
- t7 = vec_xor(t7, xor_vector);
2129
- t8 = vec_xor(t8, xor_vector);
2459
+ for (int it = 0; it < 8; it++) {
2460
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
2461
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2462
+ c1[it] = c[it][0];
2463
+ c2[it] = c[it][1];
2130
2464
  }
2131
- vec_xst(t5, 0, vecOffset+192);
2132
- vec_xst(t6, 0, vecOffset+208);
2133
- vec_xst(t7, 0, vecOffset+224);
2134
- vec_xst(t8, 0, vecOffset+240);
2135
-
2136
- aoffset1 += lda;
2137
- aoffset2 += lda;
2138
- aoffset3 += lda;
2139
- aoffset4 += lda;
2140
- aoffset5 += lda;
2141
- aoffset6 += lda;
2142
- aoffset7 += lda;
2143
- aoffset8 += lda;
2465
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
2466
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
2467
+ vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
2468
+ vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
2469
+ for (int it = 0; it < 8; it++)
2470
+ aoffsets[it] += lda;
2144
2471
  vecOffset += 256;
2145
2472
  i--;
2146
2473
  } while(i > 0);
@@ -2148,131 +2475,54 @@ class tinyBLAS_Q0_PPC {
2148
2475
  j--;
2149
2476
  } while(j > 0);
2150
2477
  }
2151
-
2152
2478
  if (rows & 4) {
2153
- aoffset1 = aoffset;
2154
- aoffset2 = aoffset1 + lda;
2155
- aoffset3 = aoffset2 + lda;
2156
- aoffset4 = aoffset3 + lda;
2157
- aoffset += 4 * lda;
2158
-
2479
+ aoffsets[0] = aoffset;
2480
+ for (int it = 1; it < 4; it++ )
2481
+ aoffsets[it] = aoffsets[it-1] + lda;
2482
+ aoffset += 4 * lda;
2159
2483
  i = (cols >> 3);
2160
2484
  if (i > 0) {
2161
2485
  do {
2162
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
2163
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
2164
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
2165
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
2166
-
2167
- __builtin_vsx_disassemble_pair(c1, &C1);
2168
- __builtin_vsx_disassemble_pair(c2, &C2);
2169
- __builtin_vsx_disassemble_pair(c3, &C3);
2170
- __builtin_vsx_disassemble_pair(c4, &C4);
2171
-
2172
- t1 = vec_perm(c1[0], c2[0], swiz1);
2173
- t2 = vec_perm(c1[0], c2[0], swiz2);
2174
- t3 = vec_perm(c3[0], c4[0], swiz1);
2175
- t4 = vec_perm(c3[0], c4[0], swiz2);
2176
- t5 = vec_perm(t1, t3, swiz3);
2177
- t6 = vec_perm(t1, t3, swiz4);
2178
- t7 = vec_perm(t2, t4, swiz3);
2179
- t8 = vec_perm(t2, t4, swiz4);
2180
- if (flip == true) {
2181
- t5 = vec_xor(t5, xor_vector);
2182
- t6 = vec_xor(t6, xor_vector);
2183
- t7 = vec_xor(t7, xor_vector);
2184
- t8 = vec_xor(t8, xor_vector);
2486
+ for (int it = 0; it < 4; it++) {
2487
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
2488
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2489
+ c1[it] = c[it][0];
2490
+ c2[it] = c[it][1];
2185
2491
  }
2186
- vec_xst(t5, 0, vecOffset);
2187
- vec_xst(t6, 0, vecOffset+16);
2188
- vec_xst(t7, 0, vecOffset+32);
2189
- vec_xst(t8, 0, vecOffset+48);
2190
-
2191
- t1 = vec_perm(c1[1], c2[1], swiz1);
2192
- t2 = vec_perm(c1[1], c2[1], swiz2);
2193
- t3 = vec_perm(c3[1], c4[1], swiz1);
2194
- t4 = vec_perm(c3[1], c4[1], swiz2);
2195
- t5 = vec_perm(t1, t3, swiz3);
2196
- t6 = vec_perm(t1, t3, swiz4);
2197
- t7 = vec_perm(t2, t4, swiz3);
2198
- t8 = vec_perm(t2, t4, swiz4);
2199
- if (flip == true) {
2200
- t5 = vec_xor(t5, xor_vector);
2201
- t6 = vec_xor(t6, xor_vector);
2202
- t7 = vec_xor(t7, xor_vector);
2203
- t8 = vec_xor(t8, xor_vector);
2492
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
2493
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
2494
+ for (int it = 0; it < 4; it++) {
2495
+ aoffsets[it] += lda;
2204
2496
  }
2205
- vec_xst(t5, 0, vecOffset+64);
2206
- vec_xst(t6, 0, vecOffset+80);
2207
- vec_xst(t7, 0, vecOffset+96);
2208
- vec_xst(t8, 0, vecOffset+112);
2209
-
2210
- aoffset1 += lda;
2211
- aoffset2 += lda;
2212
- aoffset3 += lda;
2213
- aoffset4 += lda;
2214
2497
  vecOffset += 128;
2215
2498
  i--;
2216
2499
  } while(i > 0);
2217
2500
  }
2218
2501
  }
2502
+
2219
2503
  if (rows & 3) {
2220
- aoffset1 = aoffset;
2221
- aoffset2 = aoffset1 + lda;
2222
- aoffset3 = aoffset2 + lda;
2504
+ aoffsets[0] = aoffset;
2505
+ for (int it = 1; it < 3; it++ )
2506
+ aoffsets[it] = aoffsets[it-1] + lda;
2223
2507
  i = (cols >> 3);
2224
2508
  if (i > 0) {
2225
2509
  do {
2226
2510
  switch(rows) {
2227
- case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
2228
- __builtin_vsx_disassemble_pair(c3, &C3);
2229
- case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
2230
- __builtin_vsx_disassemble_pair(c2, &C2);
2231
- case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
2232
- __builtin_vsx_disassemble_pair(c1, &C1);
2511
+ case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
2512
+ __builtin_vsx_disassemble_pair(c[2], &arr[2]);
2513
+ c1[2] = c[2][0]; c2[2] = c[2][1];
2514
+ case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
2515
+ __builtin_vsx_disassemble_pair(c[1], &arr[1]);
2516
+ c1[1] = c[1][0]; c2[1] = c[1][1];
2517
+ case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
2518
+ __builtin_vsx_disassemble_pair(c[0], &arr[0]);
2519
+ c1[0] = c[0][0]; c2[0] = c[0][1];
2233
2520
  break;
2234
2521
  }
2235
- t1 = vec_perm(c1[0], c2[0], swiz1);
2236
- t2 = vec_perm(c1[0], c2[0], swiz2);
2237
- t3 = vec_perm(c3[0], c4[0], swiz1);
2238
- t4 = vec_perm(c3[0], c4[0], swiz2);
2239
- t5 = vec_perm(t1, t3, swiz3);
2240
- t6 = vec_perm(t1, t3, swiz4);
2241
- t7 = vec_perm(t2, t4, swiz3);
2242
- t8 = vec_perm(t2, t4, swiz4);
2243
- if (flip == true) {
2244
- t5 = vec_xor(t5, xor_vector);
2245
- t6 = vec_xor(t6, xor_vector);
2246
- t7 = vec_xor(t7, xor_vector);
2247
- t8 = vec_xor(t8, xor_vector);
2248
- }
2249
- vec_xst(t5, 0, vecOffset);
2250
- vec_xst(t6, 0, vecOffset+16);
2251
- vec_xst(t7, 0, vecOffset+32);
2252
- vec_xst(t8, 0, vecOffset+48);
2253
-
2254
- t1 = vec_perm(c1[1], c2[1], swiz1);
2255
- t2 = vec_perm(c1[1], c2[1], swiz2);
2256
- t3 = vec_perm(c3[1], c4[1], swiz1);
2257
- t4 = vec_perm(c3[1], c4[1], swiz2);
2258
- t5 = vec_perm(t1, t3, swiz3);
2259
- t6 = vec_perm(t1, t3, swiz4);
2260
- t7 = vec_perm(t2, t4, swiz3);
2261
- t8 = vec_perm(t2, t4, swiz4);
2262
- if (flip == true) {
2263
- t5 = vec_xor(t5, xor_vector);
2264
- t6 = vec_xor(t6, xor_vector);
2265
- t7 = vec_xor(t7, xor_vector);
2266
- t8 = vec_xor(t8, xor_vector);
2267
- }
2268
- vec_xst(t5, 0, vecOffset+64);
2269
- vec_xst(t6, 0, vecOffset+80);
2270
- vec_xst(t7, 0, vecOffset+96);
2271
- vec_xst(t8, 0, vecOffset+112);
2272
-
2273
- aoffset1 += lda;
2274
- aoffset2 += lda;
2275
- aoffset3 += lda;
2522
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
2523
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
2524
+ for (int it = 0; it < 3; it++)
2525
+ aoffsets[it] += lda;
2276
2526
  vecOffset += 128;
2277
2527
  i--;
2278
2528
  } while(i > 0);
@@ -2280,161 +2530,46 @@ class tinyBLAS_Q0_PPC {
2280
2530
  }
2281
2531
  }
2282
2532
 
2283
- void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2284
- int64_t mc, nc, mp, np;
2285
- int m_rem = MIN(m - m0, 8);
2286
- int n_rem = MIN(n - n0, 8);
2287
- // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance
2288
- // issues. After resolving them, below code will be enabled.
2289
- /*if (m_rem >= 16 && n_rem >= 8) {
2290
- mc = 16;
2291
- nc = 8;
2292
- gemm<16,8>(m0, m, n0, n);
2293
- } else if(m_rem >= 8 && n_rem >= 16) {
2294
- mc = 8;
2295
- nc = 16;
2296
- gemm<8,16>(m0, m, n0, n);
2297
- }*/
2533
+ template<typename TA>
2534
+ void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2535
+ int m_rem = MIN(m - m0, 16);
2536
+ int n_rem = MIN(n - n0, 16);
2537
+
2538
+ int mc = 0, nc = 0;
2539
+
2298
2540
  if (m_rem >= 8 && n_rem >= 8) {
2299
- mc = 8;
2300
- nc = 8;
2301
- gemm<8,8>(m0, m, n0, n);
2541
+ mc = 8;
2542
+ nc = 8;
2543
+ gemm<8, 8>(m0, m, n0, n);
2302
2544
  } else if (m_rem >= 4 && n_rem >= 8) {
2303
2545
  mc = 4;
2304
2546
  nc = 8;
2305
- gemm<4,8>(m0, m, n0, n);
2547
+ gemm<4, 8>(m0, m, n0, n);
2306
2548
  } else if (m_rem >= 8 && n_rem >= 4) {
2307
2549
  mc = 8;
2308
2550
  nc = 4;
2309
- gemm<8,4>(m0, m, n0, n);
2551
+ gemm<8, 4>(m0, m, n0, n);
2310
2552
  } else if (m_rem >= 4 && n_rem >= 4) {
2311
2553
  mc = 4;
2312
2554
  nc = 4;
2313
- gemm_small<4, 4>(m0, m, n0, n);
2314
- } else if ((m_rem < 4) && (n_rem > 4)) {
2315
- nc = 4;
2316
- switch(m_rem) {
2317
- case 1:
2318
- mc = 1;
2319
- gemm_small<1, 4>(m0, m, n0, n);
2320
- break;
2321
- case 2:
2322
- mc = 2;
2323
- gemm_small<2, 4>(m0, m, n0, n);
2324
- break;
2325
- case 3:
2326
- mc = 3;
2327
- gemm_small<3, 4>(m0, m, n0, n);
2328
- break;
2329
- default:
2330
- return;
2331
- }
2332
- } else if ((m_rem > 4) && (n_rem < 4)) {
2333
- mc = 4;
2334
- switch(n_rem) {
2335
- case 1:
2336
- nc = 1;
2337
- gemm_small<4, 1>(m0, m, n0, n);
2338
- break;
2339
- case 2:
2340
- nc = 2;
2341
- gemm_small<4, 2>(m0, m, n0, n);
2342
- break;
2343
- case 3:
2344
- nc = 3;
2345
- gemm_small<4, 3>(m0, m, n0, n);
2346
- break;
2347
- default:
2348
- return;
2349
- }
2555
+ gemm_small(m0, m, n0, n, mc, nc);
2350
2556
  } else {
2351
- switch((m_rem << 4) | n_rem) {
2352
- case 0x43:
2353
- mc = 4;
2354
- nc = 3;
2355
- gemm_small<4, 3>(m0, m, n0, n);
2356
- break;
2357
- case 0x42:
2358
- mc = 4;
2359
- nc = 2;
2360
- gemm_small<4, 2>(m0, m, n0, n);
2361
- break;
2362
- case 0x41:
2363
- mc = 4;
2364
- nc = 1;
2365
- gemm_small<4, 1>(m0, m, n0, n);
2366
- break;
2367
- case 0x34:
2368
- mc = 3;
2369
- nc = 4;
2370
- gemm_small<3, 4>(m0, m, n0, n);
2371
- break;
2372
- case 0x33:
2373
- mc = 3;
2374
- nc = 3;
2375
- gemm_small<3, 3>(m0, m, n0, n);
2376
- break;
2377
- case 0x32:
2378
- mc = 3;
2379
- nc = 2;
2380
- gemm_small<3, 2>(m0, m, n0, n);
2381
- break;
2382
- case 0x31:
2383
- mc = 3;
2384
- nc = 1;
2385
- gemm_small<3, 1>(m0, m, n0, n);
2386
- break;
2387
- case 0x24:
2388
- mc = 2;
2389
- nc = 4;
2390
- gemm_small<2, 4>(m0, m, n0, n);
2391
- break;
2392
- case 0x23:
2393
- mc = 2;
2394
- nc = 3;
2395
- gemm_small<2, 3>(m0, m, n0, n);
2396
- break;
2397
- case 0x22:
2398
- mc = 2;
2399
- nc = 2;
2400
- gemm_small<2, 2>(m0, m, n0, n);
2401
- break;
2402
- case 0x21:
2403
- mc = 2;
2404
- nc = 1;
2405
- gemm_small<2, 1>(m0, m, n0, n);
2406
- break;
2407
- case 0x14:
2408
- mc = 1;
2409
- nc = 4;
2410
- gemm_small<1, 4>(m0, m, n0, n);
2411
- break;
2412
- case 0x13:
2413
- mc = 1;
2414
- nc = 3;
2415
- gemm_small<1, 3>(m0, m, n0, n);
2416
- break;
2417
- case 0x12:
2418
- mc = 1;
2419
- nc = 2;
2420
- gemm_small<1, 2>(m0, m, n0, n);
2421
- break;
2422
- case 0x11:
2423
- mc = 1;
2424
- nc = 1;
2425
- gemm_small<1, 1>(m0, m, n0, n);
2426
- break;
2427
- default:
2428
- return;
2429
- }
2557
+ mc = (m_rem >= 4) ? 4 : m_rem;
2558
+ nc = (n_rem >= 4) ? 4 : n_rem;
2559
+ if (mc == 0 || nc == 0)
2560
+ return;
2561
+ gemm_small(m0, m, n0, n, mc, nc);
2430
2562
  }
2431
- mp = m0 + (m - m0) / mc * mc;
2432
- np = n0 + (n - n0) / nc * nc;
2563
+
2564
+ int64_t mp = m0 + ((m - m0) / mc) * mc;
2565
+ int64_t np = n0 + ((n - n0) / nc) * nc;
2433
2566
  mnpack(mp, m, n0, np);
2434
2567
  mnpack(m0, m, np, n);
2435
2568
  }
2436
2569
 
2437
- void KERNEL_4x8(int64_t ii, int64_t jj) {
2570
+
2571
+ template<typename TA>
2572
+ void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
2438
2573
  vec_t vec_A[8], vec_B[16] = {0};
2439
2574
  acc_t acc_0, acc_1;
2440
2575
  std::array<int, 4> comparray {};
@@ -2445,9 +2580,9 @@ class tinyBLAS_Q0_PPC {
2445
2580
  __builtin_mma_xxsetaccz(&acc_0);
2446
2581
  __builtin_mma_xxsetaccz(&acc_1);
2447
2582
  if (std::is_same_v<TA, block_q4_0>) {
2448
- packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
2583
+ packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
2449
2584
  } else {
2450
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
2585
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
2451
2586
  }
2452
2587
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
2453
2588
  for(int x = 0; x < 8; x++) {
@@ -2472,14 +2607,15 @@ class tinyBLAS_Q0_PPC {
2472
2607
  aoffset += lda;
2473
2608
  }
2474
2609
  }
2475
- compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
2476
- compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
2610
+ compute(&acc_0, 0, 0, comparray, vs, fin_res);
2611
+ compute(&acc_1, 0, 4, comparray, vs, fin_res);
2477
2612
  }
2478
- save_res<4, 4>(ii, jj, 0, fin_res);
2479
- save_res<4, 4>(ii, jj+4, 4, fin_res);
2613
+ save_res(ii, jj, 0, fin_res);
2614
+ save_res(ii, jj+4, 4, fin_res);
2480
2615
  }
2481
2616
 
2482
- void KERNEL_8x4(int64_t ii, int64_t jj) {
2617
+ template<typename TA>
2618
+ void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
2483
2619
  vec_t vec_A[16], vec_B[8] = {0};
2484
2620
  acc_t acc_0, acc_1;
2485
2621
  std::array<int, 8> comparray {};
@@ -2490,9 +2626,9 @@ class tinyBLAS_Q0_PPC {
2490
2626
  __builtin_mma_xxsetaccz(&acc_0);
2491
2627
  __builtin_mma_xxsetaccz(&acc_1);
2492
2628
  if (std::is_same_v<TA, block_q4_0>) {
2493
- packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2629
+ packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2494
2630
  } else {
2495
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2631
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2496
2632
  }
2497
2633
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
2498
2634
  for(int x = 0; x < 8; x++) {
@@ -2516,16 +2652,18 @@ class tinyBLAS_Q0_PPC {
2516
2652
  aoffset += lda;
2517
2653
  }
2518
2654
  }
2519
- compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
2520
- compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
2655
+ compute(&acc_0, 0, 0, comparray, vs, fin_res);
2656
+ compute(&acc_1, 4, 4, comparray, vs, fin_res);
2521
2657
  }
2522
- save_res<4, 4>(ii, jj, 0, fin_res);
2523
- save_res<4, 4>(ii+4, jj, 4, fin_res);
2658
+ save_res(ii, jj, 0, fin_res);
2659
+ save_res(ii+4, jj, 4, fin_res);
2524
2660
  }
2525
2661
 
2526
- void KERNEL_8x8(int64_t ii, int64_t jj) {
2662
+ template<typename TA>
2663
+ void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
2527
2664
  vec_t vec_A[16], vec_B[16] = {0};
2528
2665
  acc_t acc_0, acc_1, acc_2, acc_3;
2666
+ acc_t acc_4, acc_5, acc_6, acc_7;
2529
2667
  std::array<int, 8> comparray {};
2530
2668
  vector float fin_res[16] = {0};
2531
2669
  vector float vs[16] = {0};
@@ -2536,9 +2674,9 @@ class tinyBLAS_Q0_PPC {
2536
2674
  __builtin_mma_xxsetaccz(&acc_2);
2537
2675
  __builtin_mma_xxsetaccz(&acc_3);
2538
2676
  if (std::is_same_v<TA, block_q4_0>) {
2539
- packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2677
+ packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2540
2678
  } else {
2541
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2679
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2542
2680
  }
2543
2681
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
2544
2682
  for(int x = 0; x < 8; x++) {
@@ -2565,19 +2703,19 @@ class tinyBLAS_Q0_PPC {
2565
2703
  aoffset += lda;
2566
2704
  }
2567
2705
  }
2568
- compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
2569
- compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
2570
- compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
2571
- compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
2706
+ compute(&acc_0, 0, 0, comparray, vs, fin_res);
2707
+ compute(&acc_1, 4, 4, comparray, vs, fin_res);
2708
+ compute(&acc_2, 0, 8, comparray, vs, fin_res);
2709
+ compute(&acc_3, 4, 12, comparray, vs, fin_res);
2572
2710
  }
2573
- save_res<4, 4>(ii, jj, 0, fin_res);
2574
- save_res<4, 4>(ii+4, jj, 4, fin_res);
2575
- save_res<4, 4>(ii, jj+4, 8, fin_res);
2576
- save_res<4, 4>(ii+4, jj+4, 12, fin_res);
2711
+ save_res(ii, jj, 0, fin_res);
2712
+ save_res(ii+4, jj, 4, fin_res);
2713
+ save_res(ii, jj+4, 8, fin_res);
2714
+ save_res(ii+4, jj+4, 12, fin_res);
2577
2715
  }
2578
2716
 
2579
- template<int RM, int RN>
2580
- void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2717
+ template<typename TA>
2718
+ void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
2581
2719
  int64_t ytiles = (m - m0) / RM;
2582
2720
  int64_t xtiles = (n - n0) / RN;
2583
2721
  int64_t tiles = xtiles * ytiles;
@@ -2606,9 +2744,9 @@ class tinyBLAS_Q0_PPC {
2606
2744
  __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
2607
2745
  __builtin_mma_xxsetaccz(&acc_0);
2608
2746
  if (isAblock_q4) {
2609
- packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
2747
+ packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
2610
2748
  } else {
2611
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
2749
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
2612
2750
  }
2613
2751
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
2614
2752
  for(int x = 0; x < 8; x+=4) {
@@ -2641,25 +2779,13 @@ class tinyBLAS_Q0_PPC {
2641
2779
  fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
2642
2780
  }
2643
2781
  }
2644
- save_res<RM, RN>(ii, jj, 0, fin_res);
2782
+ save_res(ii, jj, 0, fin_res, RM, RN);
2645
2783
  }
2646
2784
  }
2647
2785
 
2648
- template<int RM, int RN>
2649
- inline void kernel(int64_t ii, int64_t jj) {
2650
- if constexpr(RM == 4 && RN == 8) {
2651
- KERNEL_4x8(ii,jj);
2652
- } else if constexpr(RM == 8 && RN == 4) {
2653
- KERNEL_8x4(ii,jj);
2654
- } else if constexpr(RM == 8 && RN == 8) {
2655
- KERNEL_8x8(ii,jj);
2656
- } else {
2657
- static_assert(false, "RN/RM values not supported");
2658
- }
2659
- }
2660
-
2786
+ template<typename TA>
2661
2787
  template <int RM, int RN>
2662
- NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2788
+ NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2663
2789
  int64_t ytiles = (m - m0) / RM;
2664
2790
  int64_t xtiles = (n - n0) / RN;
2665
2791
  int64_t tiles = xtiles * ytiles;
@@ -2671,283 +2797,190 @@ class tinyBLAS_Q0_PPC {
2671
2797
  for (int64_t job = start; job < end; ++job) {
2672
2798
  int64_t ii = m0 + job / xtiles * RM;
2673
2799
  int64_t jj = n0 + job % xtiles * RN;
2674
- kernel<RM, RN>(ii, jj);
2800
+ this->kernel<RM, RN>(ii, jj);
2675
2801
  }
2676
2802
  }
2677
2803
 
2678
- const TA *const A;
2679
- const TB *const B;
2680
- TC *C;
2681
- TA *At;
2682
- TB *Bt;
2683
- const int64_t k;
2684
- const int64_t lda;
2685
- const int64_t ldb;
2686
- const int64_t ldc;
2687
- const int ith;
2688
- const int nth;
2689
- };
2804
+ template class tinyBLAS_Q0_PPC<block_q4_0>;
2805
+ template class tinyBLAS_Q0_PPC<block_q8_0>;
2690
2806
 
2691
- template <typename TA, typename TB, typename TC>
2692
2807
  class tinyBLAS_PPC {
2693
2808
  public:
2694
2809
  tinyBLAS_PPC(int64_t k,
2695
- const TA *A, int64_t lda,
2696
- const TB *B, int64_t ldb,
2697
- TC *C, int64_t ldc,
2810
+ const float * A, int64_t lda,
2811
+ const float * B, int64_t ldb,
2812
+ float * C, int64_t ldc,
2698
2813
  int ith, int nth)
2699
2814
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
2700
2815
  }
2701
2816
 
2702
2817
  void matmul(int64_t m, int64_t n) {
2703
- mnpack(0, m, 0, n);
2818
+ int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
2819
+ if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
2820
+ matmul_tiled(m, n, mc, nc, kc);
2821
+ } else {
2822
+ mnpack(0, m, 0, n);
2823
+ }
2704
2824
  }
2705
2825
 
2706
2826
  private:
2707
2827
 
2708
- void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
2828
+ inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
2829
+ vec_t vec_C[4];
2830
+ __builtin_mma_disassemble_acc(vec_C, ACC);
2831
+ for (int I = 0; I < 4; I++) {
2832
+ for (int J = 0; J < 4; J++) {
2833
+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
2834
+ }
2835
+ }
2836
+ }
2837
+
2838
+ inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
2839
+ vec_t vec_C[4];
2840
+ __builtin_mma_disassemble_acc(vec_C, ACC);
2841
+ for (int I = 0; I < 4; I++) {
2842
+ for (int J = 0; J < 4; J++) {
2843
+ float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
2844
+ *c_ptr += *((float *)&vec_C[I]+J);
2845
+ }
2846
+ }
2847
+ }
2848
+
2849
+ inline void vector_permute_store_4(vector float * src, float * vecOffset) {
2850
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
2851
+ t1 = vec_mergeh(src[0], src[1]);
2852
+ t2 = vec_mergeh(src[2], src[3]);
2853
+ t3 = vec_mergel(src[0], src[1]);
2854
+ t4 = vec_mergel(src[2], src[3]);
2855
+
2856
+ t5 = vec_xxpermdi(t1, t2, 0);
2857
+ t6 = vec_xxpermdi(t1, t2, 3);
2858
+ t7 = vec_xxpermdi(t3, t4, 0);
2859
+ t8 = vec_xxpermdi(t3, t4, 3);
2860
+
2861
+ vec_xst(t5, 0, vecOffset);
2862
+ vec_xst(t6, 0, vecOffset + 4);
2863
+ vec_xst(t7, 0, vecOffset + 8);
2864
+ vec_xst(t8, 0, vecOffset + 12);
2865
+ }
2866
+
2867
+ inline void vector_permute_store_8(vector float * src, float * vecOffset) {
2868
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
2869
+ t1 = vec_mergeh(src[0], src[1]);
2870
+ t2 = vec_mergeh(src[2], src[3]);
2871
+ t3 = vec_mergeh(src[4], src[5]);
2872
+ t4 = vec_mergeh(src[6], src[7]);
2873
+
2874
+ t5 = vec_xxpermdi(t1, t2, 0);
2875
+ t6 = vec_xxpermdi(t3, t4, 0);
2876
+ t7 = vec_xxpermdi(t1, t2, 3);
2877
+ t8 = vec_xxpermdi(t3, t4, 3);
2878
+
2879
+ vec_xst(t5, 0, vecOffset);
2880
+ vec_xst(t6, 0, vecOffset + 4);
2881
+ vec_xst(t7, 0, vecOffset + 8);
2882
+ vec_xst(t8, 0, vecOffset + 12);
2883
+
2884
+ t1 = vec_mergel(src[0], src[1]);
2885
+ t2 = vec_mergel(src[2], src[3]);
2886
+ t3 = vec_mergel(src[4], src[5]);
2887
+ t4 = vec_mergel(src[6], src[7]);
2888
+
2889
+ t5 = vec_xxpermdi(t1, t2, 0);
2890
+ t6 = vec_xxpermdi(t3, t4, 0);
2891
+ t7 = vec_xxpermdi(t1, t2, 3);
2892
+ t8 = vec_xxpermdi(t3, t4, 3);
2709
2893
 
2710
- template<typename VA>
2711
- void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
2894
+ vec_xst(t5, 0, vecOffset + 16);
2895
+ vec_xst(t6, 0, vecOffset + 20);
2896
+ vec_xst(t7, 0, vecOffset + 24);
2897
+ vec_xst(t8, 0, vecOffset + 28);
2898
+ }
2899
+
2900
+ void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
2712
2901
  int64_t i, j;
2713
- TA *aoffset = NULL, *boffset = NULL;
2714
- TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
2715
- TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
2716
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
2717
- VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
2718
- VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
2719
- VA t1, t2, t3, t4, t5, t6, t7, t8;
2720
- aoffset = const_cast<TA*>(a);
2902
+ float * aoffsets[8];
2903
+ float * aoffset = NULL, * boffset = NULL;
2904
+ __vector_pair arr[8];
2905
+ vector float c[8][2] = {0};
2906
+ vector float c1[8] = {0};
2907
+ vector float c2[8] = {0};
2908
+ aoffset = const_cast<float *>(a);
2721
2909
  boffset = vec;
2722
2910
  j = (rows >> 3);
2723
2911
  if (j > 0) {
2724
-
2725
2912
  do {
2726
- aoffset1 = aoffset;
2727
- aoffset2 = aoffset1 + lda;
2728
- aoffset3 = aoffset2 + lda;
2729
- aoffset4 = aoffset3 + lda;
2730
- aoffset5 = aoffset4 + lda;
2731
- aoffset6 = aoffset5 + lda;
2732
- aoffset7 = aoffset6 + lda;
2733
- aoffset8 = aoffset7 + lda;
2913
+ aoffsets[0] = aoffset;
2914
+ for (int it = 1; it < 8; it++)
2915
+ aoffsets[it] = aoffsets[it-1] + lda;
2734
2916
  aoffset += 8 * lda;
2735
2917
  i = (cols >> 3);
2736
2918
  if (i > 0) {
2737
2919
  do {
2738
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
2739
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
2740
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
2741
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
2742
- C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
2743
- C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
2744
- C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
2745
- C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
2746
- __builtin_vsx_disassemble_pair(c1, &C1);
2747
- __builtin_vsx_disassemble_pair(c2, &C2);
2748
- __builtin_vsx_disassemble_pair(c3, &C3);
2749
- __builtin_vsx_disassemble_pair(c4, &C4);
2750
- __builtin_vsx_disassemble_pair(c5, &C5);
2751
- __builtin_vsx_disassemble_pair(c6, &C6);
2752
- __builtin_vsx_disassemble_pair(c7, &C7);
2753
- __builtin_vsx_disassemble_pair(c8, &C8);
2754
-
2755
- t1 = vec_mergeh(c1[0], c2[0]);
2756
- t2 = vec_mergeh(c3[0], c4[0]);
2757
- t3 = vec_mergeh(c5[0], c6[0]);
2758
- t4 = vec_mergeh(c7[0], c8[0]);
2759
- t5 = vec_xxpermdi(t1, t2, 0);
2760
- t6 = vec_xxpermdi(t3, t4, 0);
2761
- t7 = vec_xxpermdi(t1, t2, 3);
2762
- t8 = vec_xxpermdi(t3, t4, 3);
2763
- vec_xst(t5, 0, boffset);
2764
- vec_xst(t6, 0, boffset+4);
2765
- vec_xst(t7, 0, boffset+8);
2766
- vec_xst(t8, 0, boffset+12);
2767
-
2768
- t1 = vec_mergel(c1[0], c2[0]);
2769
- t2 = vec_mergel(c3[0], c4[0]);
2770
- t3 = vec_mergel(c5[0], c6[0]);
2771
- t4 = vec_mergel(c7[0], c8[0]);
2772
- t5 = vec_xxpermdi(t1, t2, 0);
2773
- t6 = vec_xxpermdi(t3, t4, 0);
2774
- t7 = vec_xxpermdi(t1, t2, 3);
2775
- t8 = vec_xxpermdi(t3, t4, 3);
2776
- vec_xst(t5, 0, boffset+16);
2777
- vec_xst(t6, 0, boffset+20);
2778
- vec_xst(t7, 0, boffset+24);
2779
- vec_xst(t8, 0, boffset+28);
2780
-
2781
- t1 = vec_mergeh(c1[1], c2[1]);
2782
- t2 = vec_mergeh(c3[1], c4[1]);
2783
- t3 = vec_mergeh(c5[1], c6[1]);
2784
- t4 = vec_mergeh(c7[1], c8[1]);
2785
- t5 = vec_xxpermdi(t1, t2, 0);
2786
- t6 = vec_xxpermdi(t3, t4, 0);
2787
- t7 = vec_xxpermdi(t1, t2, 3);
2788
- t8 = vec_xxpermdi(t3, t4, 3);
2789
- vec_xst(t5, 0, boffset+32);
2790
- vec_xst(t6, 0, boffset+36);
2791
- vec_xst(t7, 0, boffset+40);
2792
- vec_xst(t8, 0, boffset+44);
2793
-
2794
- t1 = vec_mergel(c1[1], c2[1]);
2795
- t2 = vec_mergel(c3[1], c4[1]);
2796
- t3 = vec_mergel(c5[1], c6[1]);
2797
- t4 = vec_mergel(c7[1], c8[1]);
2798
- t5 = vec_xxpermdi(t1, t2, 0);
2799
- t6 = vec_xxpermdi(t3, t4, 0);
2800
- t7 = vec_xxpermdi(t1, t2, 3);
2801
- t8 = vec_xxpermdi(t3, t4, 3);
2802
- vec_xst(t5, 0, boffset+48);
2803
- vec_xst(t6, 0, boffset+52);
2804
- vec_xst(t7, 0, boffset+56);
2805
- vec_xst(t8, 0, boffset+60);
2806
-
2807
- aoffset1 += 8*lda;
2808
- aoffset2 += 8*lda;
2809
- aoffset3 += 8*lda;
2810
- aoffset4 += 8*lda;
2920
+ for (int it = 0; it < 8; it++) {
2921
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
2922
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2923
+ c1[it] = c[it][0];
2924
+ c2[it] = c[it][1];
2925
+ }
2926
+
2927
+ vector_permute_store_8(c1, boffset);
2928
+ vector_permute_store_8(c2, boffset + 32);
2811
2929
  boffset += 64;
2812
2930
  i--;
2931
+ if (i > 0) {
2932
+ for (int it = 0; it < 8; it++) {
2933
+ aoffsets[it] = aoffsets[it] + 8;
2934
+ }
2935
+ }
2813
2936
  } while(i > 0);
2814
2937
  }
2815
2938
  if (cols & 4) {
2816
- c1[0] = vec_xl(0, aoffset1);
2817
- c2[0] = vec_xl(0, aoffset2);
2818
- c3[0] = vec_xl(0, aoffset3);
2819
- c4[0] = vec_xl(0, aoffset4);
2820
- c5[0] = vec_xl(0, aoffset5);
2821
- c6[0] = vec_xl(0, aoffset6);
2822
- c7[0] = vec_xl(0, aoffset7);
2823
- c8[0] = vec_xl(0, aoffset8);
2824
-
2825
- t1 = vec_mergeh(c1[0], c2[0]);
2826
- t2 = vec_mergeh(c3[0], c4[0]);
2827
- t3 = vec_mergeh(c5[0], c6[0]);
2828
- t4 = vec_mergeh(c7[0], c8[0]);
2829
- t5 = vec_xxpermdi(t1, t2, 0);
2830
- t6 = vec_xxpermdi(t3, t4, 0);
2831
- t7 = vec_xxpermdi(t1, t2, 3);
2832
- t8 = vec_xxpermdi(t3, t4, 3);
2833
- vec_xst(t5, 0, boffset);
2834
- vec_xst(t6, 0, boffset+4);
2835
- vec_xst(t7, 0, boffset+8);
2836
- vec_xst(t8, 0, boffset+12);
2837
-
2838
- t1 = vec_mergel(c1[0], c2[0]);
2839
- t2 = vec_mergel(c3[0], c4[0]);
2840
- t3 = vec_mergel(c5[0], c6[0]);
2841
- t4 = vec_mergel(c7[0], c8[0]);
2842
- t5 = vec_xxpermdi(t1, t2, 0);
2843
- t6 = vec_xxpermdi(t3, t4, 0);
2844
- t7 = vec_xxpermdi(t1, t2, 3);
2845
- t8 = vec_xxpermdi(t3, t4, 3);
2846
- vec_xst(t5, 0, boffset+16);
2847
- vec_xst(t6, 0, boffset+20);
2848
- vec_xst(t7, 0, boffset+24);
2849
- vec_xst(t8, 0, boffset+28);
2939
+ for (int it = 0; it < 8 ; it++)
2940
+ c1[it] = vec_xl(0, aoffsets[it]);
2941
+ vector_permute_store_8(c1, boffset);
2850
2942
  }
2851
2943
  j--;
2852
2944
  } while(j > 0);
2853
2945
  }
2854
2946
 
2855
2947
  if (rows & 4) {
2856
- aoffset1 = aoffset;
2857
- aoffset2 = aoffset1 + lda;
2858
- aoffset3 = aoffset2 + lda;
2859
- aoffset4 = aoffset3 + lda;
2948
+ aoffsets[0] = aoffset;
2949
+ for (int it = 1; it < 4; it++)
2950
+ aoffsets[it] = aoffsets[it-1] + lda;
2860
2951
  aoffset += 4 * lda;
2861
2952
  i = (cols >> 3);
2862
2953
  if (i > 0) {
2863
2954
  do {
2864
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
2865
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
2866
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
2867
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
2868
- __builtin_vsx_disassemble_pair(c1, &C1);
2869
- __builtin_vsx_disassemble_pair(c2, &C2);
2870
- __builtin_vsx_disassemble_pair(c3, &C3);
2871
- __builtin_vsx_disassemble_pair(c4, &C4);
2872
-
2873
- t1 = vec_mergeh(c1[0], c2[0]);
2874
- t2 = vec_mergeh(c3[0], c4[0]);
2875
- t3 = vec_mergel(c1[0], c2[0]);
2876
- t4 = vec_mergel(c3[0], c4[0]);
2877
- t5 = vec_xxpermdi(t1, t2, 0);
2878
- t6 = vec_xxpermdi(t1, t2, 3);
2879
- t7 = vec_xxpermdi(t3, t4, 0);
2880
- t8 = vec_xxpermdi(t3, t4, 3);
2881
- vec_xst(t5, 0, boffset);
2882
- vec_xst(t6, 0, boffset+4);
2883
- vec_xst(t7, 0, boffset+8);
2884
- vec_xst(t8, 0, boffset+12);
2885
-
2886
- t1 = vec_mergeh(c1[1], c2[1]);
2887
- t2 = vec_mergeh(c3[1], c4[1]);
2888
- t3 = vec_mergel(c1[1], c2[1]);
2889
- t4 = vec_mergel(c3[1], c4[1]);
2890
- t5 = vec_xxpermdi(t1, t2, 0);
2891
- t6 = vec_xxpermdi(t1, t2, 3);
2892
- t7 = vec_xxpermdi(t3, t4, 0);
2893
- t8 = vec_xxpermdi(t3, t4, 3);
2894
- vec_xst(t5, 0, boffset+16);
2895
- vec_xst(t6, 0, boffset+20);
2896
- vec_xst(t7, 0, boffset+24);
2897
- vec_xst(t8, 0, boffset+28);
2898
-
2899
- aoffset1 += 8*lda;
2900
- aoffset2 += 8*lda;
2901
- aoffset3 += 8*lda;
2902
- aoffset4 += 8*lda;
2955
+ for (int it = 0; it < 4; it++) {
2956
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
2957
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2958
+ c1[it] = c[it][0];
2959
+ c2[it] = c[it][1];
2960
+ }
2961
+ vector_permute_store_4(c1, boffset);
2962
+ vector_permute_store_4(c2, boffset + 16);
2963
+ for (int it = 0; it < 4; it++)
2964
+ aoffsets[it] += 8 * lda;
2903
2965
  boffset += 32;
2904
2966
  i--;
2905
2967
  } while(i > 0);
2906
2968
  }
2907
2969
 
2908
2970
  if (cols & 4) {
2909
- c1[0] = vec_xl(0, aoffset1);
2910
- c2[0] = vec_xl(0, aoffset2);
2911
- c3[0] = vec_xl(0, aoffset3);
2912
- c4[0] = vec_xl(0, aoffset4);
2913
-
2914
- t1 = vec_mergeh(c1[0], c2[0]);
2915
- t2 = vec_mergeh(c3[0], c4[0]);
2916
- t3 = vec_xxpermdi(t1, t2, 0);
2917
- t4 = vec_xxpermdi(t1, t2, 3);
2918
- vec_xst(t3, 0, boffset);
2919
- vec_xst(t4, 0, boffset+4);
2920
-
2921
- t1 = vec_mergel(c1[0], c2[0]);
2922
- t2 = vec_mergel(c3[0], c4[0]);
2923
- t3 = vec_xxpermdi(t1, t2, 0);
2924
- t4 = vec_xxpermdi(t1, t2, 3);
2925
- vec_xst(t3, 0, boffset+8);
2926
- vec_xst(t4, 0, boffset+12);
2971
+ for (int it = 0; it < 4; it++)
2972
+ c1[it] = vec_xl(0, aoffsets[it]);
2973
+ vector_permute_store_4(c1, boffset);
2927
2974
  }
2928
2975
  }
2929
2976
  if (rows & 3) {
2930
- aoffset1 = aoffset;
2931
- aoffset2 = aoffset1 + lda;
2932
- aoffset3 = aoffset2 + lda;
2977
+ aoffsets[0] = aoffset;
2978
+ for (int it = 1; it < 3; it++)
2979
+ aoffsets[it] = aoffsets[it-1] + lda;
2933
2980
  if (cols & 4) {
2934
- c1[0] = vec_xl(0, aoffset1);
2935
- c2[0] = vec_xl(0, aoffset2);
2936
- c3[0] = vec_xl(0, aoffset3);
2937
-
2938
- t1 = vec_mergeh(c1[0], c2[0]);
2939
- t2 = vec_mergeh(c3[0], c4[0]);
2940
- t3 = vec_xxpermdi(t1, t2, 0);
2941
- t4 = vec_xxpermdi(t1, t2, 3);
2942
- vec_xst(t3, 0, boffset);
2943
- vec_xst(t4, 0, boffset+4);
2944
-
2945
- t1 = vec_mergel(c1[0], c2[0]);
2946
- t2 = vec_mergel(c3[0], c4[0]);
2947
- t3 = vec_xxpermdi(t1, t2, 0);
2948
- t4 = vec_xxpermdi(t1, t2, 3);
2949
- vec_xst(t3, 0, boffset+8);
2950
- vec_xst(t4, 0, boffset+12);
2981
+ for (int it = 0; it < 3; it++)
2982
+ c1[it] = vec_xl(0, aoffsets[it]);
2983
+ vector_permute_store_4(c1, boffset);
2951
2984
  }
2952
2985
  }
2953
2986
  }
@@ -2956,15 +2989,15 @@ class tinyBLAS_PPC {
2956
2989
  vec_t vec_A[4], vec_B[4], vec_C[4];
2957
2990
  acc_t acc_0;
2958
2991
  __builtin_mma_xxsetaccz(&acc_0);
2959
- for (int l = 0; l < k; l+=4) {
2960
- packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
2961
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
2992
+ for (int l = 0; l < k; l += 4) {
2993
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
2994
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
2962
2995
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
2963
2996
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
2964
2997
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
2965
2998
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
2966
2999
  }
2967
- SAVE_ACC(&acc_0, ii, jj);
3000
+ save_acc(&acc_0, ii, jj);
2968
3001
  }
2969
3002
 
2970
3003
  void KERNEL_4x8(int64_t ii, int64_t jj) {
@@ -2972,9 +3005,9 @@ class tinyBLAS_PPC {
2972
3005
  acc_t acc_0, acc_1;
2973
3006
  __builtin_mma_xxsetaccz(&acc_0);
2974
3007
  __builtin_mma_xxsetaccz(&acc_1);
2975
- for (int64_t l = 0; l < k; l+=4) {
2976
- packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
2977
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
3008
+ for (int64_t l = 0; l < k; l += 4) {
3009
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
3010
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
2978
3011
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
2979
3012
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
2980
3013
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -2984,8 +3017,8 @@ class tinyBLAS_PPC {
2984
3017
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
2985
3018
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
2986
3019
  }
2987
- SAVE_ACC(&acc_0, ii, jj);
2988
- SAVE_ACC(&acc_1, ii, jj+4);
3020
+ save_acc(&acc_0, ii, jj);
3021
+ save_acc(&acc_1, ii, jj + 4);
2989
3022
  }
2990
3023
 
2991
3024
  void KERNEL_8x4(int64_t ii, int64_t jj) {
@@ -2993,9 +3026,9 @@ class tinyBLAS_PPC {
2993
3026
  acc_t acc_0, acc_1;
2994
3027
  __builtin_mma_xxsetaccz(&acc_0);
2995
3028
  __builtin_mma_xxsetaccz(&acc_1);
2996
- for (int64_t l = 0; l < k; l+=4) {
2997
- packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
2998
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
3029
+ for (int64_t l = 0; l < k; l += 4) {
3030
+ packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
3031
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
2999
3032
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
3000
3033
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
3001
3034
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -3005,8 +3038,8 @@ class tinyBLAS_PPC {
3005
3038
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
3006
3039
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
3007
3040
  }
3008
- SAVE_ACC(&acc_0, ii, jj);
3009
- SAVE_ACC(&acc_1, ii+4, jj);
3041
+ save_acc(&acc_0, ii, jj);
3042
+ save_acc(&acc_1, ii + 4, jj);
3010
3043
  }
3011
3044
 
3012
3045
  void KERNEL_8x8(int64_t ii, int64_t jj) {
@@ -3017,173 +3050,132 @@ class tinyBLAS_PPC {
3017
3050
  __builtin_mma_xxsetaccz(&acc_2);
3018
3051
  __builtin_mma_xxsetaccz(&acc_3);
3019
3052
  for (int l = 0; l < k; l+=8) {
3020
- packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
3021
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
3053
+ packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
3054
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
3022
3055
  for(int x = 0; x < 16; x+=2) {
3023
3056
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
3024
- __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
3025
- __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
3026
- __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
3057
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
3058
+ __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
3059
+ __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
3060
+ }
3061
+ }
3062
+ save_acc(&acc_0, ii, jj);
3063
+ save_acc(&acc_1, ii, jj + 4);
3064
+ save_acc(&acc_2, ii + 4, jj);
3065
+ save_acc(&acc_3, ii + 4, jj + 4);
3066
+ }
3067
+
3068
+ inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
3069
+ for (int x = 0; x < 16; x += 2) {
3070
+ __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
3071
+ __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
3072
+ __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
3073
+ __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
3074
+ __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
3075
+ __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
3076
+ __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
3077
+ __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
3078
+ }
3079
+ }
3080
+
3081
+ void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
3082
+ for (int64_t i = 0; i < mc; i += 16) {
3083
+ int A_base_addr = (mc / 8) * (i / 8) * 16;
3084
+ for (int64_t j = 0; j < nc; j += 8) {
3085
+ int B_base_addr = (nc / 8) * (j / 8) * 16;
3086
+ acc_t acc[8];
3087
+ vec_t A0_block[16]; vec_t A1_block[16];
3088
+ for (int x = 0; x < 8; x++)
3089
+ __builtin_mma_xxsetaccz(&acc[x]);
3090
+ for (int64_t l = 0; l < kc; l += 8) {
3091
+ int A0_block_idx = A_base_addr + (l / 8) * 16;
3092
+ int A1_block_idx = A0_block_idx + (mc / 8) * 16;
3093
+ int B_block_idx = B_base_addr + (l / 8) * 16;
3094
+ vec_t* A0_block = &vec_A[A0_block_idx];
3095
+ vec_t* A1_block = &vec_A[A1_block_idx];
3096
+ vec_t* B_block = &vec_B[B_block_idx];
3097
+ MMA_16x8(A0_block, A1_block, B_block, acc);
3098
+ }
3099
+ if (kk == 0) {
3100
+ save_acc(&acc[0], ii + i, jj + j);
3101
+ save_acc(&acc[1], ii + i, jj + j + 4);
3102
+ save_acc(&acc[2], ii + i + 4, jj + j);
3103
+ save_acc(&acc[3], ii + i + 4, jj + j + 4);
3104
+ save_acc(&acc[4], ii + i + 8, jj + j);
3105
+ save_acc(&acc[5], ii + i + 8, jj + j + 4);
3106
+ save_acc(&acc[6], ii + i + 12, jj + j);
3107
+ save_acc(&acc[7], ii + i + 12, jj + j + 4);
3108
+ } else {
3109
+ add_save_acc(&acc[0], ii + i, jj + j);
3110
+ add_save_acc(&acc[1], ii + i, jj + j + 4);
3111
+ add_save_acc(&acc[2], ii + i + 4, jj + j);
3112
+ add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
3113
+ add_save_acc(&acc[4], ii + i + 8, jj + j);
3114
+ add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
3115
+ add_save_acc(&acc[6], ii + i + 12, jj + j);
3116
+ add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
3117
+ }
3118
+ }
3119
+ }
3120
+ }
3121
+
3122
+ void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) {
3123
+ int64_t ytiles = m / mc;
3124
+ int64_t xtiles = n / nc;
3125
+ int64_t tiles = xtiles * ytiles;
3126
+ int64_t duty = (tiles + nth - 1) / nth;
3127
+ int64_t start = duty * ith;
3128
+ int64_t end = start + duty;
3129
+ if (end > tiles) {
3130
+ end = tiles;
3131
+ }
3132
+ for (int64_t job = start; job < end; ++job) {
3133
+ int64_t ii = (job / xtiles) * mc;
3134
+ int64_t jj = (job % xtiles) * nc;
3135
+ for (int64_t kk = 0; kk < k; kk += kc) {
3136
+ vec_t A_pack[kc * mc / 4];
3137
+ vec_t B_pack[kc * nc / 4];
3138
+ packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
3139
+ packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
3140
+ KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
3027
3141
  }
3028
3142
  }
3029
- SAVE_ACC(&acc_0, ii, jj);
3030
- SAVE_ACC(&acc_1, ii, jj+4);
3031
- SAVE_ACC(&acc_2, ii+4, jj);
3032
- SAVE_ACC(&acc_3, ii+4, jj+4);
3033
3143
  }
3034
3144
 
3035
3145
  void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
3036
- int64_t mc, nc, mp, np;
3037
- int m_rem = MIN(m - m0, 16);
3038
- int n_rem = MIN(n - n0, 16);
3039
- if (m_rem >= 16 && n_rem >= 8) {
3040
- mc = 8;
3041
- nc = 8;
3042
- gemm<8,8>(m0, m, n0, n);
3043
- } else if(m_rem >= 8 && n_rem >= 16) {
3044
- mc = 8;
3045
- nc = 8;
3046
- gemm<8,8>(m0, m, n0, n);
3047
- } else if (m_rem >= 8 && n_rem >= 8) {
3146
+ int m_rem = MIN(m - m0, 8);
3147
+ int n_rem = MIN(n - n0, 8);
3148
+ int mc = 0, nc = 0;
3149
+ if (m_rem >= 8 && n_rem >= 8) {
3048
3150
  mc = 8;
3049
3151
  nc = 8;
3050
- gemm<8,8>(m0, m, n0, n);
3152
+ gemm<8, 8>(m0, m, n0, n);
3051
3153
  } else if (m_rem >= 4 && n_rem >= 8) {
3052
3154
  mc = 4;
3053
3155
  nc = 8;
3054
- gemm<4,8>(m0, m, n0, n);
3156
+ gemm<4, 8>(m0, m, n0, n);
3055
3157
  } else if (m_rem >= 8 && n_rem >= 4) {
3056
3158
  mc = 8;
3057
3159
  nc = 4;
3058
- gemm<8,4>(m0, m, n0, n);
3160
+ gemm<8, 4>(m0, m, n0, n);
3059
3161
  } else if (m_rem >= 4 && n_rem >= 4) {
3060
3162
  mc = 4;
3061
3163
  nc = 4;
3062
- gemm<4,4>(m0, m, n0, n);
3063
- } else if ((m_rem < 4) && (n_rem > 4)) {
3064
- nc = 4;
3065
- switch(m_rem) {
3066
- case 1:
3067
- mc = 1;
3068
- gemm_small(m0, m, n0, n, mc, nc);
3069
- break;
3070
- case 2:
3071
- mc = 2;
3072
- gemm_small(m0, m, n0, n, mc, nc);
3073
- break;
3074
- case 3:
3075
- mc = 3;
3076
- gemm_small(m0, m, n0, n, mc, nc);
3077
- break;
3078
- default:
3079
- return;
3080
- }
3081
- } else if ((m_rem > 4) && (n_rem < 4)) {
3082
- mc = 4;
3083
- switch(n_rem) {
3084
- case 1:
3085
- nc = 1;
3086
- gemm_small(m0, m, n0, n, mc, nc);
3087
- break;
3088
- case 2:
3089
- nc = 2;
3090
- gemm_small(m0, m, n0, n, mc, nc);
3091
- break;
3092
- case 3:
3093
- nc = 3;
3094
- gemm_small(m0, m, n0, n, mc, nc);
3095
- break;
3096
- default:
3097
- return;
3098
- }
3164
+ gemm<4, 4>(m0, m, n0, n);
3099
3165
  } else {
3100
- switch((m_rem << 4) | n_rem) {
3101
- case 0x43:
3102
- mc = 4;
3103
- nc = 3;
3104
- gemm_small(m0, m, n0, n, mc, nc);
3105
- break;
3106
- case 0x42:
3107
- mc = 4;
3108
- nc = 2;
3109
- gemm_small(m0, m, n0, n, mc, nc);
3110
- break;
3111
- case 0x41:
3112
- mc = 4;
3113
- nc = 1;
3114
- gemm_small(m0, m, n0, n, mc, nc);
3115
- break;
3116
- case 0x34:
3117
- mc = 3;
3118
- nc = 4;
3119
- gemm_small(m0, m, n0, n, mc, nc);
3120
- break;
3121
- case 0x33:
3122
- mc = 3;
3123
- nc = 3;
3124
- gemm_small(m0, m, n0, n, mc, nc);
3125
- break;
3126
- case 0x32:
3127
- mc = 3;
3128
- nc = 2;
3129
- gemm_small(m0, m, n0, n, mc, nc);
3130
- break;
3131
- case 0x31:
3132
- mc = 3;
3133
- nc = 1;
3134
- gemm_small(m0, m, n0, n, mc, nc);
3135
- break;
3136
- case 0x24:
3137
- mc = 2;
3138
- nc = 4;
3139
- gemm_small(m0, m, n0, n, mc, nc);
3140
- break;
3141
- case 0x23:
3142
- mc = 2;
3143
- nc = 3;
3144
- gemm_small(m0, m, n0, n, mc, nc);
3145
- break;
3146
- case 0x22:
3147
- mc = 2;
3148
- nc = 2;
3149
- gemm_small(m0, m, n0, n, mc, nc);
3150
- break;
3151
- case 0x21:
3152
- mc = 2;
3153
- nc = 1;
3154
- gemm_small(m0, m, n0, n, mc, nc);
3155
- break;
3156
- case 0x14:
3157
- mc = 1;
3158
- nc = 4;
3159
- gemm_small(m0, m, n0, n, mc, nc);
3160
- break;
3161
- case 0x13:
3162
- mc = 1;
3163
- nc = 3;
3164
- gemm_small(m0, m, n0, n, mc, nc);
3165
- break;
3166
- case 0x12:
3167
- mc = 1;
3168
- nc = 2;
3169
- gemm_small(m0, m, n0, n, mc, nc);
3170
- break;
3171
- case 0x11:
3172
- mc = 1;
3173
- nc = 1;
3174
- gemm_small(m0, m, n0, n, mc, nc);
3175
- break;
3176
- default:
3177
- return;
3178
- }
3166
+ mc = (m_rem >= 4) ? 4 : m_rem;
3167
+ nc = (n_rem >= 4) ? 4 : n_rem;
3168
+ if (mc == 0 || nc == 0)
3169
+ return;
3170
+ gemm_small(m0, m, n0, n, mc, nc);
3179
3171
  }
3180
- mp = m0 + (m - m0) / mc * mc;
3181
- np = n0 + (n - n0) / nc * nc;
3172
+ int64_t mp = m0 + ((m - m0) / mc) * mc;
3173
+ int64_t np = n0 + ((n - n0) / nc) * nc;
3182
3174
  mnpack(mp, m, n0, np);
3183
3175
  mnpack(m0, m, np, n);
3184
3176
  }
3185
3177
 
3186
- void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
3178
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
3187
3179
  int64_t ytiles = (m - m0) / RM;
3188
3180
  int64_t xtiles = (n - n0) / RN;
3189
3181
  int64_t tiles = xtiles * ytiles;
@@ -3198,30 +3190,30 @@ class tinyBLAS_PPC {
3198
3190
  vec_t vec_C[4];
3199
3191
  acc_t acc_0;
3200
3192
  __builtin_mma_xxsetaccz(&acc_0);
3201
- vec_t vec_A[4] {0}, vec_B[4] = {0};
3202
- for (int l=0; l<k; l+=4) {
3193
+ vec_t vec_A[4] = {0}, vec_B[4] = {0};
3194
+ for (int l = 0; l < k; l += 4) {
3203
3195
  /* 'GEMV Forwarding' concept is used in first two conditional loops.
3204
3196
  * when one of the matrix has a single row/column, the elements are
3205
3197
  * broadcasted, instead of using packing routine to prepack the
3206
3198
  * matrix elements.
3207
3199
  */
3208
3200
  if (RM == 1) {
3209
- TA* a = const_cast<TA*>(A+(ii)*lda+l);
3210
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
3201
+ float * a = const_cast<float *>(A + (ii) * lda + l);
3202
+ packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
3211
3203
  vec_A[0] = (vec_t)vec_xl(0,a);
3212
- vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
3213
- vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
3214
- vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
3204
+ vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
3205
+ vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
3206
+ vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
3215
3207
  } else if (RN == 1) {
3216
- packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
3217
- TB* b = const_cast<TB*>(B+(jj)*ldb+l);
3208
+ packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
3209
+ float * b = const_cast<float *>(B + (jj) * ldb + l);
3218
3210
  vec_B[0] = (vec_t)vec_xl(0,b);
3219
- vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
3220
- vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
3221
- vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
3211
+ vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
3212
+ vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
3213
+ vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
3222
3214
  } else {
3223
- packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
3224
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
3215
+ packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
3216
+ packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
3225
3217
  }
3226
3218
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
3227
3219
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -3231,12 +3223,27 @@ class tinyBLAS_PPC {
3231
3223
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
3232
3224
  for (int I = 0; I < RM; I++) {
3233
3225
  for (int J = 0; J < RN; J++) {
3234
- *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
3226
+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
3235
3227
  }
3236
3228
  }
3237
3229
  }
3238
3230
  }
3239
3231
 
3232
+ template<int RM, int RN>
3233
+ inline void kernel(int64_t ii, int64_t jj) {
3234
+ if constexpr(RM == 4 && RN == 4) {
3235
+ KERNEL_4x4(ii, jj);
3236
+ } else if constexpr(RM == 4 && RN == 8) {
3237
+ KERNEL_4x8(ii, jj);
3238
+ } else if constexpr(RM == 8 && RN == 4) {
3239
+ KERNEL_8x4(ii, jj);
3240
+ } else if constexpr(RM == 8 && RN == 8) {
3241
+ KERNEL_8x8(ii, jj);
3242
+ } else {
3243
+ static_assert(false, "RN/RM values not supported");
3244
+ }
3245
+ }
3246
+
3240
3247
  template <int RM, int RN>
3241
3248
  NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
3242
3249
  int64_t ytiles = (m - m0) / RM;
@@ -3245,29 +3252,18 @@ class tinyBLAS_PPC {
3245
3252
  int64_t duty = (tiles + nth - 1) / nth;
3246
3253
  int64_t start = duty * ith;
3247
3254
  int64_t end = start + duty;
3248
- if (RM == 4 && RN == 4) {
3249
- kernel = &tinyBLAS_PPC::KERNEL_4x4;
3250
- } else if (RM == 4 && RN == 8) {
3251
- kernel = &tinyBLAS_PPC::KERNEL_4x8;
3252
- } else if (RM == 8 && RN == 4) {
3253
- kernel = &tinyBLAS_PPC::KERNEL_8x4;
3254
- } else if (RM == 8 && RN == 8) {
3255
- kernel = &tinyBLAS_PPC::KERNEL_8x8;
3256
- }
3257
3255
  if (end > tiles)
3258
3256
  end = tiles;
3259
3257
  for (int64_t job = start; job < end; ++job) {
3260
3258
  int64_t ii = m0 + job / xtiles * RM;
3261
3259
  int64_t jj = n0 + job % xtiles * RN;
3262
- (this->*kernel)(ii, jj);
3260
+ kernel<RM, RN>(ii, jj);
3263
3261
  }
3264
3262
  }
3265
3263
 
3266
- const TA *const A;
3267
- const TB *const B;
3268
- TC *C;
3269
- TA *At;
3270
- TB *Bt;
3264
+ const float * const A;
3265
+ const float * const B;
3266
+ float * C;
3271
3267
  const int64_t k;
3272
3268
  const int64_t lda;
3273
3269
  const int64_t ldb;
@@ -3366,13 +3362,31 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
3366
3362
  #elif defined(__MMA__)
3367
3363
  if (k % 8)
3368
3364
  return false;
3369
- tinyBLAS_PPC<float, float, float> tb{
3365
+ tinyBLAS_PPC tb{
3370
3366
  k, (const float *)A, lda,
3371
3367
  (const float *)B, ldb,
3372
3368
  (float *)C, ldc,
3373
3369
  params->ith, params->nth};
3374
3370
  tb.matmul(m, n);
3375
3371
  return true;
3372
+ #elif defined(__riscv_zvfh)
3373
+ #if LMUL == 1
3374
+ tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
3375
+ k, (const float *)A, lda,
3376
+ (const float *)B, ldb,
3377
+ (float *)C, ldc};
3378
+ #elif LMUL == 2
3379
+ tinyBLAS_RVV<vfloat32m2_t, vfloat32m2_t, float, float, float> tb{ params,
3380
+ k, (const float *)A, lda,
3381
+ (const float *)B, ldb,
3382
+ (float *)C, ldc};
3383
+ #else // LMUL = 4
3384
+ tinyBLAS_RVV<vfloat32m4_t, vfloat32m4_t, float, float, float> tb{ params,
3385
+ k, (const float *)A, lda,
3386
+ (const float *)B, ldb,
3387
+ (float *)C, ldc};
3388
+ #endif
3389
+ return tb.matmul(m, n);
3376
3390
  #else
3377
3391
  return false;
3378
3392
  #endif
@@ -3415,6 +3429,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
3415
3429
  tb.matmul(m, n);
3416
3430
  return true;
3417
3431
  }
3432
+ #elif defined(__riscv_zvfbfwma)
3433
+ #if LMUL == 1
3434
+ tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
3435
+ k, (const ggml_bf16_t *)A, lda,
3436
+ (const ggml_bf16_t *)B, ldb,
3437
+ (float *)C, ldc};
3438
+ #elif LMUL == 2
3439
+ tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
3440
+ k, (const ggml_bf16_t *)A, lda,
3441
+ (const ggml_bf16_t *)B, ldb,
3442
+ (float *)C, ldc};
3443
+ #else // LMUL = 4
3444
+ tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
3445
+ k, (const ggml_bf16_t *)A, lda,
3446
+ (const ggml_bf16_t *)B, ldb,
3447
+ (float *)C, ldc};
3448
+ #endif
3449
+ return tb.matmul(m, n);
3418
3450
  #endif
3419
3451
  return false;
3420
3452
  }
@@ -3464,6 +3496,26 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
3464
3496
  (float *)C, ldc};
3465
3497
  return tb.matmul(m, n);
3466
3498
  }
3499
+ #elif defined(__riscv_zvfh)
3500
+ if (Btype == GGML_TYPE_F16) {
3501
+ #if LMUL == 1
3502
+ tinyBLAS_RVV<vfloat32m1_t, vfloat16mf2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
3503
+ k, (const ggml_fp16_t *)A, lda,
3504
+ (const ggml_fp16_t *)B, ldb,
3505
+ (float *)C, ldc};
3506
+ #elif LMUL == 2
3507
+ tinyBLAS_RVV<vfloat32m2_t, vfloat16m1_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
3508
+ k, (const ggml_fp16_t *)A, lda,
3509
+ (const ggml_fp16_t *)B, ldb,
3510
+ (float *)C, ldc};
3511
+ #else // LMUL = 4
3512
+ tinyBLAS_RVV<vfloat32m4_t, vfloat16m2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
3513
+ k, (const ggml_fp16_t *)A, lda,
3514
+ (const ggml_fp16_t *)B, ldb,
3515
+ (float *)C, ldc};
3516
+ #endif
3517
+ return tb.matmul(m, n);
3518
+ }
3467
3519
  #endif
3468
3520
  return false;
3469
3521
  }
@@ -3493,7 +3545,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
3493
3545
  return false;
3494
3546
  if (m < 8 && m != 4)
3495
3547
  return false;
3496
- tinyBLAS_Q0_PPC<block_q8_0, block_q8_0, float> tb{
3548
+ tinyBLAS_Q0_PPC<block_q8_0> tb{
3497
3549
  k, (const block_q8_0 *)A, lda,
3498
3550
  (const block_q8_0 *)B, ldb,
3499
3551
  (float *)C, ldc,
@@ -3530,7 +3582,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
3530
3582
  return false;
3531
3583
  if (m < 8 && m != 4)
3532
3584
  return false;
3533
- tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float> tb{
3585
+ tinyBLAS_Q0_PPC<block_q4_0> tb{
3534
3586
  k, (const block_q4_0 *)A, lda,
3535
3587
  (const block_q8_0 *)B, ldb,
3536
3588
  (float *)C, ldc,