whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -121,7 +121,8 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
121
121
  #endif
122
122
 
123
123
  #if defined(__MMA__)
124
- #include "sgemm-ppc.h"
124
+ typedef vector unsigned char vec_t;
125
+ typedef __vector_quad acc_t;
125
126
  #endif
126
127
  ////////////////////////////////////////////////////////////////////////////////////////////////////
127
128
  // VECTORIZED FUSED MULTIPLY ADD
@@ -532,7 +533,7 @@ class tinyBLAS {
532
533
  if constexpr (RN > 1) {
533
534
  return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
534
535
  } else {
535
- GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
536
+ GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N);
536
537
  GGML_ASSERT(false); // we have miss something.
537
538
  }
538
539
  }
@@ -710,7 +711,7 @@ class tinyBLAS_RVV {
710
711
  if constexpr (RN > 1) {
711
712
  return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
712
713
  } else {
713
- GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
714
+ GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N);
714
715
  GGML_ASSERT(false); // we have miss something.
715
716
  }
716
717
  }
@@ -1797,10 +1798,27 @@ class tinyBLAS_Q0_AVX {
1797
1798
  } \
1798
1799
  } \
1799
1800
 
1801
+ template<typename T>
1802
+ struct mma_instr;
1803
+
1804
+ template<>
1805
+ struct mma_instr<ggml_bf16_t> {
1806
+ static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
1807
+ __builtin_mma_xvbf16ger2pp(acc, a, b);
1808
+ }
1809
+ };
1810
+
1811
+ template<>
1812
+ struct mma_instr<ggml_fp16_t> {
1813
+ static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
1814
+ __builtin_mma_xvf16ger2pp(acc, a, b);
1815
+ }
1816
+ };
1817
+
1800
1818
  template <typename TA, typename TB, typename TC>
1801
- class tinyBLAS_BF16_PPC {
1819
+ class tinyBLAS_HP16_PPC {
1802
1820
  public:
1803
- tinyBLAS_BF16_PPC(int64_t k,
1821
+ tinyBLAS_HP16_PPC(int64_t k,
1804
1822
  const TA *A, int64_t lda,
1805
1823
  const TB *B, int64_t ldb,
1806
1824
  TC *C, int64_t ldc,
@@ -2118,8 +2136,8 @@ class tinyBLAS_BF16_PPC {
2118
2136
  packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
2119
2137
  packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
2120
2138
  for (int x = 0; x < 4; x++) {
2121
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
2122
- __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
2139
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
2140
+ mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
2123
2141
  }
2124
2142
  }
2125
2143
  SAVE_ACC(&acc_0, ii, jj);
@@ -2135,8 +2153,8 @@ class tinyBLAS_BF16_PPC {
2135
2153
  packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
2136
2154
  packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
2137
2155
  for (int x = 0; x < 4; x++) {
2138
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
2139
- __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
2156
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
2157
+ mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
2140
2158
  }
2141
2159
  }
2142
2160
  SAVE_ACC(&acc_0, ii, jj);
@@ -2155,10 +2173,10 @@ class tinyBLAS_BF16_PPC {
2155
2173
  packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
2156
2174
  packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
2157
2175
  for (int x = 0; x < 4; x++) {
2158
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
2159
- __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
2160
- __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
2161
- __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
2176
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
2177
+ mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
2178
+ mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
2179
+ mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
2162
2180
  }
2163
2181
  }
2164
2182
 
@@ -2189,7 +2207,7 @@ class tinyBLAS_BF16_PPC {
2189
2207
  packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
2190
2208
  packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
2191
2209
  for (int x = 0; x<2; x++) {
2192
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
2210
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
2193
2211
  }
2194
2212
  }
2195
2213
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
@@ -2224,8 +2242,8 @@ class tinyBLAS_BF16_PPC {
2224
2242
  packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
2225
2243
  packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
2226
2244
  for (int x = 0; x<4; x++) {
2227
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
2228
- __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
2245
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
2246
+ mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
2229
2247
  }
2230
2248
  }
2231
2249
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
@@ -2284,43 +2302,299 @@ class tinyBLAS_BF16_PPC {
2284
2302
  const int nth;
2285
2303
  };
2286
2304
 
2287
- template <typename TA>
2288
- tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
2289
- const TA *A, int64_t lda,
2290
- const block_q8_0 *B, int64_t ldb,
2291
- float *C, int64_t ldc,
2292
- int ith, int nth)
2305
+ template <typename TA>
2306
+ class tinyBLAS_Q0_PPC {
2307
+ public:
2308
+ tinyBLAS_Q0_PPC(int64_t k,
2309
+ const TA * A, int64_t lda,
2310
+ const block_q8_0 * B, int64_t ldb,
2311
+ float * C, int64_t ldc,
2312
+ int ith, int nth)
2293
2313
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
2294
- kc = 64;
2295
2314
  }
2296
2315
 
2297
- template<typename TA>
2298
- void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
2299
- int mc = 64; int nc = 64;
2300
- if (n % 8 == 0 && n < nc) {
2301
- nc = n;
2302
- mc = 32 ;
2303
- kc = 32;
2304
- }
2305
- const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
2306
- if (is_aligned) {
2307
- this->matmul_tiled_q0(m, n, mc, nc, kc);
2316
+ void matmul(int64_t m, int64_t n) {
2317
+ const int64_t mc = 64;
2318
+ const int64_t kc = 64;
2319
+ int64_t nc = 64;
2320
+ int64_t n_aligned = 0;
2321
+ if (n % 64 == 0) {
2322
+ n_aligned = n;
2323
+ } else if (n == 4) {
2324
+ n_aligned = 4;
2325
+ } else if (n < 64) {
2326
+ n_aligned = (n / 8) * 8;
2327
+ } else {
2328
+ n_aligned = (n / 64) * 64;
2329
+ }
2330
+
2331
+ if (n_aligned > 0) {
2332
+ if (n_aligned % 64 == 0) nc = 64;
2333
+ else if (n_aligned == n) nc = n;
2334
+ else if (n_aligned % 32 == 0) nc = 32;
2335
+ else if (n_aligned % 24 == 0) nc = 24;
2336
+ else if (n_aligned % 16 == 0) nc = 16;
2337
+ else nc = 8;
2338
+ }
2339
+ bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
2340
+ if (can_use_tiled) {
2341
+ matmul_tiled(m, n_aligned, mc, nc, kc);
2342
+ if (n > n_aligned) {
2343
+ mnpack(0, m, n_aligned, n);
2344
+ }
2308
2345
  } else {
2309
2346
  mnpack(0, m, 0, n);
2310
2347
  }
2311
2348
  }
2312
2349
 
2313
- template<typename TA>
2314
- template<int size>
2315
- void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
2350
+ private:
2351
+ inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {
2352
+ for (int I = 0; I < RM; I++) {
2353
+ for (int J = 0; J < RN; J++) {
2354
+ *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J);
2355
+ }
2356
+ }
2357
+ }
2358
+
2359
+ inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
2360
+ vec_t vec_C[4];
2361
+ __builtin_mma_disassemble_acc(vec_C, ACC);
2362
+ for (int I = 0; I < 4; I++) {
2363
+ for (int J = 0; J < 4; J++) {
2364
+ *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J);
2365
+ }
2366
+ }
2367
+ }
2368
+
2369
+ inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
2370
+ vec_t vec_C[4];
2371
+ __builtin_mma_disassemble_acc(vec_C, ACC);
2372
+ for (int I = 0; I < 4; I++) {
2373
+ for (int J = 0; J < 4; J++) {
2374
+ float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I);
2375
+ *c_ptr += *((float *)&vec_C[I] + J);
2376
+ }
2377
+ }
2378
+ }
2379
+
2380
+ template<typename ArrayType>
2381
+ inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) {
2382
+ vector signed int vec_C[4];
2383
+ vector float CA[4] = {0};
2384
+ vector float res[4] = {0};
2385
+ __builtin_mma_disassemble_acc(vec_C, ACC);
2386
+ for (int i = 0; i < 4; i++) {
2387
+ CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));
2388
+ res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
2389
+ fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]);
2390
+ }
2391
+ }
2392
+
2393
+ inline void process_q4_elements(vector signed char (&c)[2], int * ca) {
2394
+ const vector signed char lowMask = vec_splats((signed char)0xF);
2395
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
2396
+ const vector signed char v8 = vec_splats((signed char)0x8);
2397
+ vector signed int vsum = {0};
2398
+ vector signed int vsum2 = {0};
2399
+ c[0] = vec_and(c[1], lowMask);
2400
+ c[1] = vec_sr(c[1], v4);
2401
+ c[0] = vec_sub(c[0], v8);
2402
+ c[1] = vec_sub(c[1], v8);
2403
+ vsum = vec_sum4s(c[0], vsum);
2404
+ vsum2 = vec_sum4s(c[1], vsum2);
2405
+ vsum = vec_add(vsum, vsum2);
2406
+ *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
2407
+ }
2408
+
2409
+ template <typename V1, typename V2>
2410
+ inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {
2411
+ vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
2412
+ vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
2413
+ vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
2414
+ vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
2415
+ V2 t1, t2, t3, t4, t5, t6, t7, t8;
2416
+ vector unsigned char xor_vector;
2417
+ uint8_t flip_vec = 0x80;
2418
+ xor_vector = vec_splats(flip_vec);
2419
+ t1 = vec_perm(s1, s2, swiz1);
2420
+ t2 = vec_perm(s1, s2, swiz2);
2421
+ t3 = vec_perm(s3, s4, swiz1);
2422
+ t4 = vec_perm(s3, s4, swiz2);
2423
+ t5 = vec_perm(t1, t3, swiz3);
2424
+ t6 = vec_perm(t1, t3, swiz4);
2425
+ t7 = vec_perm(t2, t4, swiz3);
2426
+ t8 = vec_perm(t2, t4, swiz4);
2427
+ if (flip == true) {
2428
+ t5 = vec_xor(t5, xor_vector);
2429
+ t6 = vec_xor(t6, xor_vector);
2430
+ t7 = vec_xor(t7, xor_vector);
2431
+ t8 = vec_xor(t8, xor_vector);
2432
+ }
2433
+ vec_xst(t5, 0, vecOffset);
2434
+ vec_xst(t6, 0, vecOffset + 16);
2435
+ vec_xst(t7, 0, vecOffset + 32);
2436
+ vec_xst(t8, 0, vecOffset + 48);
2437
+ }
2438
+
2439
+ inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {
2440
+ const vector signed char lowMask = vec_splats((signed char)0x0F);
2441
+ const vector signed char v8 = vec_splats((signed char)0x08);
2442
+ const vector unsigned char v4 = vec_splats((unsigned char)4);
2443
+ lo = vec_and(packed, lowMask);
2444
+ hi = vec_sr(packed, v4);
2445
+ lo = vec_sub(lo, v8);
2446
+ hi = vec_sub(hi, v8);
2447
+ }
2448
+
2449
+ inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {
2450
+ vec_t t[8], s[8];
2451
+ vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
2452
+ vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
2453
+ vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
2454
+ vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
2455
+ for (int i = 0; i < 4; i += 2) {
2456
+ t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
2457
+ t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
2458
+ }
2459
+ for (int i = 4; i < 8; i += 2) {
2460
+ t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
2461
+ t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
2462
+ }
2463
+ s[0] = vec_perm(t[0], t[2], swiz3);
2464
+ s[1] = vec_perm(t[0], t[2], swiz4);
2465
+ s[2] = vec_perm(t[1], t[3], swiz3);
2466
+ s[3] = vec_perm(t[1], t[3], swiz4);
2467
+ s[4] = vec_perm(t[4], t[6], swiz3);
2468
+ s[5] = vec_perm(t[4], t[6], swiz4);
2469
+ s[6] = vec_perm(t[5], t[7], swiz3);
2470
+ s[7] = vec_perm(t[5], t[7], swiz4);
2471
+ for (int i = 0; i < 8; ++i) {
2472
+ vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16));
2473
+ }
2474
+ }
2475
+
2476
+ static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {
2477
+ vector signed short i16_hi = vec_unpackh(raw);
2478
+ vector signed short i16_lo = vec_unpackl(raw);
2479
+
2480
+ vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0);
2481
+ vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0);
2482
+ vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0);
2483
+ vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0);
2484
+ out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale));
2485
+ out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale));
2486
+ }
2487
+
2488
+ void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
2489
+ unsigned char * vecOffset = vec;
2490
+ for (int i = 0; i < rows; i += 8) {
2491
+ const block_q4_0 * rows_base[8];
2492
+ for (int r = 0; r < 8; r++) {
2493
+ rows_base[r] = a + (i + r) * lda;
2494
+ }
2495
+ for (int blk = 0; blk < blocks; blk++) {
2496
+ vector unsigned short hp_res[8][4];
2497
+ for (int r = 0; r < 8; r++) {
2498
+ const block_q4_0 * current_blk = rows_base[r] + blk;
2499
+ vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
2500
+ vector signed char v_qs = vec_xl(0, (const vector signed char *)current_blk->qs);
2501
+ vector signed char c1, c2;
2502
+ unpack_q4_to_q8(v_qs, c1, c2);
2503
+ convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
2504
+ convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);
2505
+ }
2506
+ for (int c = 0; c < 4; c++) {
2507
+ vector unsigned char c_arr[8];
2508
+ for (int r = 0; r < 8; r++) {
2509
+ c_arr[r] = (vector unsigned char)hp_res[r][c];
2510
+ }
2511
+ vector_permute_store_fp16((vec_t *)c_arr, vecOffset);
2512
+ vecOffset += 128;
2513
+ }
2514
+ }
2515
+ }
2516
+ }
2517
+
2518
+ template <int chunk_size>
2519
+ static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
2520
+ unsigned char * vecOffset = vec;
2521
+ const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
2522
+ const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
2523
+ const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
2524
+ const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
2525
+
2526
+ for (int i = 0; i < rows; i += chunk_size) {
2527
+ const block_q8_0 * rows_base[chunk_size];
2528
+ for (int r = 0; r < chunk_size; r++) {
2529
+ rows_base[r] = a + (i + r) * lda;
2530
+ }
2531
+ for (int blk = 0; blk < blocks; blk++) {
2532
+ vector unsigned short hp_res[chunk_size][4];
2533
+ for (int r = 0; r < chunk_size; r++) {
2534
+ const block_q8_0 * b = rows_base[r] + blk;
2535
+ vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));
2536
+ vector signed char c[2];
2537
+ __vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
2538
+ __builtin_vsx_disassemble_pair(c, & pair);
2539
+ convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);
2540
+ convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]);
2541
+ }
2542
+ for (int col = 0; col < 4; col++) {
2543
+ if constexpr (chunk_size == 8) {
2544
+ vec_t t[8];
2545
+ t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
2546
+ t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
2547
+ t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
2548
+ t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
2549
+ t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);
2550
+ t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);
2551
+ t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);
2552
+ t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);
2553
+
2554
+ vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));
2555
+ vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));
2556
+ vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));
2557
+ vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));
2558
+ vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));
2559
+ vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));
2560
+ vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));
2561
+ vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));
2562
+ vecOffset += 128;
2563
+ } else {
2564
+ vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
2565
+ vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
2566
+ vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
2567
+ vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
2568
+
2569
+ vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));
2570
+ vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));
2571
+ vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));
2572
+ vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));
2573
+ vecOffset += 64;
2574
+ }
2575
+ }
2576
+ }
2577
+ }
2578
+ }
2579
+
2580
+ void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
2581
+ if (rows == 4) {
2582
+ pack_q8_block<4>(a, lda, rows, blocks, vec);
2583
+ } else {
2584
+ pack_q8_block<8>(a, lda, rows, blocks, vec);
2585
+ }
2586
+ }
2587
+
2588
+ template<int size>
2589
+ void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {
2316
2590
  int64_t i, j;
2317
- TA *aoffset = NULL;
2318
- int8_t *vecOffset = NULL;
2319
- TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
2320
- TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
2591
+ TA * aoffset = NULL;
2592
+ int8_t * vecOffset = NULL;
2593
+ TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;
2594
+ TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;
2321
2595
  vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
2322
2596
  vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
2323
- aoffset = const_cast<TA*>(a);
2597
+ aoffset = const_cast<TA *>(a);
2324
2598
  vecOffset = vec;
2325
2599
  j = (rows >> 3);
2326
2600
  if (j > 0) {
@@ -2337,27 +2611,27 @@ class tinyBLAS_BF16_PPC {
2337
2611
  i = (cols >> 2);
2338
2612
  if (i > 0) {
2339
2613
  do {
2340
- c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
2341
- c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
2342
- c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
2343
- c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
2344
- c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
2345
- c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
2346
- c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
2347
- c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
2348
-
2349
- process_q4_elements(c1, &comparray[0]);
2350
- process_q4_elements(c2, &comparray[1]);
2351
- process_q4_elements(c3, &comparray[2]);
2352
- process_q4_elements(c4, &comparray[3]);
2353
- process_q4_elements(c5, &comparray[4]);
2354
- process_q4_elements(c6, &comparray[5]);
2355
- process_q4_elements(c7, &comparray[6]);
2356
- process_q4_elements(c8, &comparray[7]);
2614
+ c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
2615
+ c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
2616
+ c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
2617
+ c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);
2618
+ c5[1] = vec_xl(0, (const vector signed char *)aoffset5->qs);
2619
+ c6[1] = vec_xl(0, (const vector signed char *)aoffset6->qs);
2620
+ c7[1] = vec_xl(0, (const vector signed char *)aoffset7->qs);
2621
+ c8[1] = vec_xl(0, (const vector signed char *)aoffset8->qs);
2622
+
2623
+ process_q4_elements(c1, & comparray[0]);
2624
+ process_q4_elements(c2, & comparray[1]);
2625
+ process_q4_elements(c3, & comparray[2]);
2626
+ process_q4_elements(c4, & comparray[3]);
2627
+ process_q4_elements(c5, & comparray[4]);
2628
+ process_q4_elements(c6, & comparray[5]);
2629
+ process_q4_elements(c7, & comparray[6]);
2630
+ process_q4_elements(c8, & comparray[7]);
2357
2631
  vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
2358
- vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
2359
- vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
2360
- vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
2632
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
2633
+ vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);
2634
+ vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);
2361
2635
  aoffset1 += lda;
2362
2636
  aoffset2 += lda;
2363
2637
  aoffset3 += lda;
@@ -2383,17 +2657,17 @@ class tinyBLAS_BF16_PPC {
2383
2657
  i = (cols >> 2);
2384
2658
  if (i > 0) {
2385
2659
  do {
2386
- c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
2387
- c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
2388
- c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
2389
- c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
2390
-
2391
- process_q4_elements(c1, &comparray[0]);
2392
- process_q4_elements(c2, &comparray[1]);
2393
- process_q4_elements(c3, &comparray[2]);
2394
- process_q4_elements(c4, &comparray[3]);
2660
+ c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
2661
+ c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
2662
+ c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
2663
+ c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);
2664
+
2665
+ process_q4_elements(c1, & comparray[0]);
2666
+ process_q4_elements(c2, & comparray[1]);
2667
+ process_q4_elements(c3, & comparray[2]);
2668
+ process_q4_elements(c4, & comparray[3]);
2395
2669
  vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
2396
- vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
2670
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
2397
2671
  aoffset1 += lda;
2398
2672
  aoffset2 += lda;
2399
2673
  aoffset3 += lda;
@@ -2412,17 +2686,17 @@ class tinyBLAS_BF16_PPC {
2412
2686
  if (i > 0) {
2413
2687
  do {
2414
2688
  switch(rows) {
2415
- case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
2416
- case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
2417
- case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
2689
+ case 3: c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
2690
+ case 2: c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
2691
+ case 1: c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
2418
2692
  break;
2419
2693
  }
2420
- process_q4_elements(c1, &comparray[0]);
2421
- process_q4_elements(c2, &comparray[1]);
2422
- process_q4_elements(c3, &comparray[2]);
2423
- process_q4_elements(c4, &comparray[3]);
2694
+ process_q4_elements(c1, & comparray[0]);
2695
+ process_q4_elements(c2, & comparray[1]);
2696
+ process_q4_elements(c3, & comparray[2]);
2697
+ process_q4_elements(c4, & comparray[3]);
2424
2698
  vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
2425
- vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
2699
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
2426
2700
  aoffset1 += lda;
2427
2701
  aoffset2 += lda;
2428
2702
  aoffset3 += lda;
@@ -2433,39 +2707,38 @@ class tinyBLAS_BF16_PPC {
2433
2707
  }
2434
2708
  }
2435
2709
 
2436
- template<typename TA>
2437
2710
  template<typename VA, typename VB>
2438
- void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
2711
+ void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {
2439
2712
  int64_t i, j;
2440
- block_q8_0 *aoffset = NULL;
2441
- VA *vecOffset = NULL;
2442
- block_q8_0* aoffsets[8];
2713
+ block_q8_0 * aoffset = NULL;
2714
+ VA * vecOffset = NULL;
2715
+ block_q8_0 * aoffsets[8];
2443
2716
  __vector_pair arr[8];
2444
2717
  VB c[8][2] = {0};
2445
2718
  VB c1[8] = {0}; VB c2[8] = {0};
2446
- aoffset = const_cast<block_q8_0*>(a);
2719
+ aoffset = const_cast<block_q8_0 *>(a);
2447
2720
  vecOffset = vec;
2448
2721
  j = (rows >> 3);
2449
2722
  if (j > 0) {
2450
2723
  do {
2451
2724
  aoffsets[0] = aoffset;
2452
2725
  for (int it = 1; it < 8; it++)
2453
- aoffsets[it] = aoffsets[it-1] + lda;
2726
+ aoffsets[it] = aoffsets[it - 1] + lda;
2454
2727
  aoffset += 8 * lda;
2455
2728
 
2456
2729
  i = (cols >> 3);
2457
2730
  if (i > 0) {
2458
2731
  do {
2459
2732
  for (int it = 0; it < 8; it++) {
2460
- arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
2461
- __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2733
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
2734
+ __builtin_vsx_disassemble_pair(c[it], & arr[it]);
2462
2735
  c1[it] = c[it][0];
2463
2736
  c2[it] = c[it][1];
2464
2737
  }
2465
2738
  vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
2466
- vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
2467
- vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
2468
- vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
2739
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
2740
+ vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);
2741
+ vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);
2469
2742
  for (int it = 0; it < 8; it++)
2470
2743
  aoffsets[it] += lda;
2471
2744
  vecOffset += 256;
@@ -2484,13 +2757,13 @@ class tinyBLAS_BF16_PPC {
2484
2757
  if (i > 0) {
2485
2758
  do {
2486
2759
  for (int it = 0; it < 4; it++) {
2487
- arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
2488
- __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2760
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
2761
+ __builtin_vsx_disassemble_pair(c[it], & arr[it]);
2489
2762
  c1[it] = c[it][0];
2490
2763
  c2[it] = c[it][1];
2491
2764
  }
2492
2765
  vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
2493
- vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
2766
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
2494
2767
  for (int it = 0; it < 4; it++) {
2495
2768
  aoffsets[it] += lda;
2496
2769
  }
@@ -2503,24 +2776,24 @@ class tinyBLAS_BF16_PPC {
2503
2776
  if (rows & 3) {
2504
2777
  aoffsets[0] = aoffset;
2505
2778
  for (int it = 1; it < 3; it++ )
2506
- aoffsets[it] = aoffsets[it-1] + lda;
2779
+ aoffsets[it] = aoffsets[it - 1] + lda;
2507
2780
  i = (cols >> 3);
2508
2781
  if (i > 0) {
2509
2782
  do {
2510
2783
  switch(rows) {
2511
- case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
2512
- __builtin_vsx_disassemble_pair(c[2], &arr[2]);
2784
+ case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);
2785
+ __builtin_vsx_disassemble_pair(c[2], & arr[2]);
2513
2786
  c1[2] = c[2][0]; c2[2] = c[2][1];
2514
- case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
2515
- __builtin_vsx_disassemble_pair(c[1], &arr[1]);
2787
+ case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);
2788
+ __builtin_vsx_disassemble_pair(c[1], & arr[1]);
2516
2789
  c1[1] = c[1][0]; c2[1] = c[1][1];
2517
- case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
2518
- __builtin_vsx_disassemble_pair(c[0], &arr[0]);
2790
+ case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);
2791
+ __builtin_vsx_disassemble_pair(c[0], & arr[0]);
2519
2792
  c1[0] = c[0][0]; c2[0] = c[0][1];
2520
2793
  break;
2521
2794
  }
2522
2795
  vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
2523
- vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
2796
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
2524
2797
  for (int it = 0; it < 3; it++)
2525
2798
  aoffsets[it] += lda;
2526
2799
  vecOffset += 128;
@@ -2530,8 +2803,7 @@ class tinyBLAS_BF16_PPC {
2530
2803
  }
2531
2804
  }
2532
2805
 
2533
- template<typename TA>
2534
- void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2806
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2535
2807
  int m_rem = MIN(m - m0, 16);
2536
2808
  int n_rem = MIN(n - n0, 16);
2537
2809
 
@@ -2568,8 +2840,7 @@ class tinyBLAS_BF16_PPC {
2568
2840
  }
2569
2841
 
2570
2842
 
2571
- template<typename TA>
2572
- void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
2843
+ void KERNEL_4x8(int64_t ii, int64_t jj) {
2573
2844
  vec_t vec_A[8], vec_B[16] = {0};
2574
2845
  acc_t acc_0, acc_1;
2575
2846
  std::array<int, 4> comparray {};
@@ -2577,26 +2848,26 @@ class tinyBLAS_BF16_PPC {
2577
2848
  vector float vs[8] = {0};
2578
2849
  bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2579
2850
  for (int l = 0; l < k; l++) {
2580
- __builtin_mma_xxsetaccz(&acc_0);
2581
- __builtin_mma_xxsetaccz(&acc_1);
2851
+ __builtin_mma_xxsetaccz(& acc_0);
2852
+ __builtin_mma_xxsetaccz(& acc_1);
2582
2853
  if (std::is_same_v<TA, block_q4_0>) {
2583
- packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
2854
+ packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);
2584
2855
  } else {
2585
- packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
2856
+ packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);
2586
2857
  }
2587
- packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
2858
+ packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
2588
2859
  for(int x = 0; x < 8; x++) {
2589
- __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
2590
- __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
2860
+ __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
2861
+ __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);
2591
2862
  }
2592
2863
  for (int I = 0; I<4; I++) {
2593
2864
  for (int J = 0; J<4; J++) {
2594
- *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
2595
- *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
2865
+ *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
2866
+ *((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
2596
2867
  }
2597
2868
  }
2598
2869
  if (!isAblock_q4) {
2599
- auto aoffset = A+(ii*lda)+l;
2870
+ auto aoffset = A + (ii * lda) + l;
2600
2871
  for (int i = 0; i < 4; i++) {
2601
2872
  comparray[i] = 0;
2602
2873
  int ca = 0;
@@ -2607,15 +2878,14 @@ class tinyBLAS_BF16_PPC {
2607
2878
  aoffset += lda;
2608
2879
  }
2609
2880
  }
2610
- compute(&acc_0, 0, 0, comparray, vs, fin_res);
2611
- compute(&acc_1, 0, 4, comparray, vs, fin_res);
2881
+ compute(& acc_0, 0, 0, comparray, vs, fin_res);
2882
+ compute(& acc_1, 0, 4, comparray, vs, fin_res);
2612
2883
  }
2613
2884
  save_res(ii, jj, 0, fin_res);
2614
- save_res(ii, jj+4, 4, fin_res);
2885
+ save_res(ii, jj + 4, 4, fin_res);
2615
2886
  }
2616
2887
 
2617
- template<typename TA>
2618
- void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
2888
+ void KERNEL_8x4(int64_t ii, int64_t jj) {
2619
2889
  vec_t vec_A[16], vec_B[8] = {0};
2620
2890
  acc_t acc_0, acc_1;
2621
2891
  std::array<int, 8> comparray {};
@@ -2623,25 +2893,25 @@ class tinyBLAS_BF16_PPC {
2623
2893
  vector float vs[8] = {0};
2624
2894
  bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2625
2895
  for (int l = 0; l < k; l++) {
2626
- __builtin_mma_xxsetaccz(&acc_0);
2627
- __builtin_mma_xxsetaccz(&acc_1);
2896
+ __builtin_mma_xxsetaccz(& acc_0);
2897
+ __builtin_mma_xxsetaccz(& acc_1);
2628
2898
  if (std::is_same_v<TA, block_q4_0>) {
2629
- packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2899
+ packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
2630
2900
  } else {
2631
- packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2901
+ packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
2632
2902
  }
2633
- packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
2903
+ packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);
2634
2904
  for(int x = 0; x < 8; x++) {
2635
- __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
2636
- __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
2905
+ __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
2906
+ __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
2637
2907
  }
2638
- for (int I = 0; I<8; I++) {
2639
- for (int J = 0; J<4; J++) {
2640
- *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
2908
+ for (int I = 0; I < 8; I++) {
2909
+ for (int J = 0; J < 4; J++) {
2910
+ *((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
2641
2911
  }
2642
2912
  }
2643
2913
  if (!isAblock_q4) {
2644
- auto aoffset = A+(ii*lda)+l;
2914
+ auto aoffset = A + (ii * lda) + l;
2645
2915
  for (int i = 0; i < 8; i++) {
2646
2916
  comparray[i] = 0;
2647
2917
  int ca = 0;
@@ -2652,15 +2922,14 @@ class tinyBLAS_BF16_PPC {
2652
2922
  aoffset += lda;
2653
2923
  }
2654
2924
  }
2655
- compute(&acc_0, 0, 0, comparray, vs, fin_res);
2656
- compute(&acc_1, 4, 4, comparray, vs, fin_res);
2925
+ compute(& acc_0, 0, 0, comparray, vs, fin_res);
2926
+ compute(& acc_1, 4, 4, comparray, vs, fin_res);
2657
2927
  }
2658
2928
  save_res(ii, jj, 0, fin_res);
2659
- save_res(ii+4, jj, 4, fin_res);
2929
+ save_res(ii + 4, jj, 4, fin_res);
2660
2930
  }
2661
2931
 
2662
- template<typename TA>
2663
- void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
2932
+ void KERNEL_8x8(int64_t ii, int64_t jj) {
2664
2933
  vec_t vec_A[16], vec_B[16] = {0};
2665
2934
  acc_t acc_0, acc_1, acc_2, acc_3;
2666
2935
  acc_t acc_4, acc_5, acc_6, acc_7;
@@ -2669,30 +2938,30 @@ class tinyBLAS_BF16_PPC {
2669
2938
  vector float vs[16] = {0};
2670
2939
  bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2671
2940
  for (int l = 0; l < k; l++) {
2672
- __builtin_mma_xxsetaccz(&acc_0);
2673
- __builtin_mma_xxsetaccz(&acc_1);
2674
- __builtin_mma_xxsetaccz(&acc_2);
2675
- __builtin_mma_xxsetaccz(&acc_3);
2941
+ __builtin_mma_xxsetaccz(& acc_0);
2942
+ __builtin_mma_xxsetaccz(& acc_1);
2943
+ __builtin_mma_xxsetaccz(& acc_2);
2944
+ __builtin_mma_xxsetaccz(& acc_3);
2676
2945
  if (std::is_same_v<TA, block_q4_0>) {
2677
- packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2946
+ packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
2678
2947
  } else {
2679
- packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2948
+ packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
2680
2949
  }
2681
- packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
2950
+ packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
2682
2951
  for(int x = 0; x < 8; x++) {
2683
- __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
2684
- __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
2685
- __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
2686
- __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
2952
+ __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
2953
+ __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
2954
+ __builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);
2955
+ __builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);
2687
2956
  }
2688
- for (int I = 0; I<8; I++) {
2689
- for (int J = 0; J<4; J++) {
2690
- *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
2691
- *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
2957
+ for (int I = 0; I < 8 ; I++) {
2958
+ for (int J = 0; J < 4; J++) {
2959
+ *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
2960
+ *((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
2692
2961
  }
2693
2962
  }
2694
2963
  if (!isAblock_q4) {
2695
- auto aoffset = A+(ii*lda)+l;
2964
+ auto aoffset = A + (ii * lda) + l;
2696
2965
  for (int i = 0; i < 8; i++) {
2697
2966
  comparray[i] = 0;
2698
2967
  int ca = 0;
@@ -2703,19 +2972,99 @@ class tinyBLAS_BF16_PPC {
2703
2972
  aoffset += lda;
2704
2973
  }
2705
2974
  }
2706
- compute(&acc_0, 0, 0, comparray, vs, fin_res);
2707
- compute(&acc_1, 4, 4, comparray, vs, fin_res);
2708
- compute(&acc_2, 0, 8, comparray, vs, fin_res);
2709
- compute(&acc_3, 4, 12, comparray, vs, fin_res);
2975
+ compute(& acc_0, 0, 0, comparray, vs, fin_res);
2976
+ compute(& acc_1, 4, 4, comparray, vs, fin_res);
2977
+ compute(& acc_2, 0, 8, comparray, vs, fin_res);
2978
+ compute(& acc_3, 4, 12, comparray, vs, fin_res);
2710
2979
  }
2711
2980
  save_res(ii, jj, 0, fin_res);
2712
- save_res(ii+4, jj, 4, fin_res);
2713
- save_res(ii, jj+4, 8, fin_res);
2714
- save_res(ii+4, jj+4, 12, fin_res);
2981
+ save_res(ii + 4, jj, 4, fin_res);
2982
+ save_res(ii, jj + 4, 8, fin_res);
2983
+ save_res(ii + 4, jj + 4, 12, fin_res);
2984
+ }
2985
+
2986
+ void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {
2987
+ acc_t acc[8];
2988
+ for (int i = 0; i < mc ; i += 16) {
2989
+ for (int j = 0; j < nc; j += 8) {
2990
+ int A0_base = (i / 16) * (2 * 32 * kc);
2991
+ int B0_base = (j / 8) * (32 * kc);
2992
+ for (int x = 0; x < 8; x++) {
2993
+ __builtin_mma_xxsetaccz(&acc[x]);
2994
+ }
2995
+ for (int64_t kk = 0; kk < kc; kk++) {
2996
+ int A0_block_idx = A0_base + kk * 32;
2997
+ int B0_block_idx = B0_base + kk * 32;
2998
+ int A1_block_idx = A0_block_idx + 32 * kc;
2999
+ int B1_block_idx = B0_block_idx + 32 * kc;
3000
+ vec_t * A0_block = & vec_A[A0_block_idx];
3001
+ vec_t * B0_block = & vec_B[B0_block_idx];
3002
+ vec_t * A1_block = & vec_A[A1_block_idx];
3003
+ for (int it = 0; it < 4; it++) {
3004
+ for (int x = 0; x < 4; x++) {
3005
+ __builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);
3006
+ __builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);
3007
+ __builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);
3008
+ __builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
3009
+ __builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);
3010
+ __builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);
3011
+ __builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);
3012
+ __builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
3013
+ }
3014
+ }
3015
+ }
3016
+ if (l == 0) {
3017
+ save_acc(& acc[0], ii + i, jj + j);
3018
+ save_acc(& acc[1], ii + i, jj + j + 4);
3019
+ save_acc(& acc[2], ii + i + 4, jj + j);
3020
+ save_acc(& acc[3], ii + i + 4, jj + j + 4);
3021
+ save_acc(& acc[4], ii + i + 8, jj + j);
3022
+ save_acc(& acc[5], ii + i + 8, jj + j + 4);
3023
+ save_acc(& acc[6], ii + i + 12, jj + j);
3024
+ save_acc(& acc[7], ii + i + 12, jj + j + 4);
3025
+ } else {
3026
+ add_save_acc(& acc[0], ii + i, jj + j);
3027
+ add_save_acc(& acc[1], ii + i, jj + j + 4);
3028
+ add_save_acc(& acc[2], ii + i + 4, jj + j);
3029
+ add_save_acc(& acc[3], ii + i + 4, jj + j + 4);
3030
+ add_save_acc(& acc[4], ii + i + 8, jj + j);
3031
+ add_save_acc(& acc[5], ii + i + 8, jj + j + 4);
3032
+ add_save_acc(& acc[6], ii + i + 12, jj + j);
3033
+ add_save_acc(& acc[7], ii + i + 12, jj + j + 4);
3034
+ }
3035
+ }
3036
+ }
2715
3037
  }
2716
3038
 
2717
- template<typename TA>
2718
- void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
3039
+ void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
3040
+ vec_t A_pack[mc * kc * 4];
3041
+ vec_t B_pack[nc * kc * 4];
3042
+ constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
3043
+ int64_t ytiles = m / mc;
3044
+ int64_t xtiles = n / nc;
3045
+ int64_t tiles = xtiles * ytiles;
3046
+ int64_t duty = (tiles + nth - 1) / nth;
3047
+ int64_t start = duty * ith;
3048
+ int64_t end = start + duty;
3049
+ if (end > tiles) {
3050
+ end = tiles;
3051
+ }
3052
+ for (int64_t job = start; job < end; ++job) {
3053
+ int64_t ii = (job / xtiles) * mc;
3054
+ int64_t jj = (job % xtiles) * nc;
3055
+ for (int64_t kk = 0; kk < k; kk += kc) {
3056
+ if constexpr(is_Ablock_q4) {
3057
+ packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
3058
+ } else {
3059
+ packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
3060
+ }
3061
+ packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
3062
+ KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
3063
+ }
3064
+ }
3065
+ }
3066
+
3067
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
2719
3068
  int64_t ytiles = (m - m0) / RM;
2720
3069
  int64_t xtiles = (n - n0) / RN;
2721
3070
  int64_t tiles = xtiles * ytiles;
@@ -2737,32 +3086,32 @@ class tinyBLAS_BF16_PPC {
2737
3086
  vector float fin_res[4] = {0};
2738
3087
  vector float vs[4] = {0};
2739
3088
  vector float CA[4] = {0};
2740
- __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
2741
- __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
3089
+ __builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value
3090
+ __builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value
2742
3091
  for (int l = 0; l < k; l++) {
2743
- __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
2744
- __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
2745
- __builtin_mma_xxsetaccz(&acc_0);
3092
+ __builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
3093
+ __builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
3094
+ __builtin_mma_xxsetaccz(& acc_0);
2746
3095
  if (isAblock_q4) {
2747
- packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
3096
+ packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);
2748
3097
  } else {
2749
- packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
3098
+ packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);
2750
3099
  }
2751
- packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
2752
- for(int x = 0; x < 8; x+=4) {
2753
- __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
2754
- __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
2755
- __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
2756
- __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
3100
+ packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);
3101
+ for (int x = 0; x < 8; x += 4) {
3102
+ __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
3103
+ __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);
3104
+ __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);
3105
+ __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);
2757
3106
  }
2758
- for (int I = 0; I<RM; I++) {
2759
- for (int J = 0; J<RN; J++) {
2760
- *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
3107
+ for (int I = 0; I < RM; I++) {
3108
+ for (int J = 0; J < RN; J++) {
3109
+ *((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
2761
3110
  }
2762
3111
  }
2763
- __builtin_mma_disassemble_acc(vec_C, &acc_0);
3112
+ __builtin_mma_disassemble_acc(vec_C, & acc_0);
2764
3113
  if (!isAblock_q4) {
2765
- auto aoffset = A+(ii*lda)+l;
3114
+ auto aoffset = A + (ii * lda) + l;
2766
3115
  for (int i = 0; i < RM; i++) {
2767
3116
  comparray[i] = 0;
2768
3117
  int ca = 0;
@@ -2783,9 +3132,21 @@ class tinyBLAS_BF16_PPC {
2783
3132
  }
2784
3133
  }
2785
3134
 
2786
- template<typename TA>
3135
+ template<int RM, int RN>
3136
+ inline void kernel(int64_t ii, int64_t jj) {
3137
+ if constexpr(RM == 4 && RN == 8) {
3138
+ KERNEL_4x8(ii,jj);
3139
+ } else if constexpr(RM == 8 && RN == 4) {
3140
+ KERNEL_8x4(ii,jj);
3141
+ } else if constexpr(RM == 8 && RN == 8) {
3142
+ KERNEL_8x8(ii,jj);
3143
+ } else {
3144
+ assert(false && "RN/RM values not supported");
3145
+ }
3146
+ }
3147
+
2787
3148
  template <int RM, int RN>
2788
- NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
3149
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2789
3150
  int64_t ytiles = (m - m0) / RM;
2790
3151
  int64_t xtiles = (n - n0) / RN;
2791
3152
  int64_t tiles = xtiles * ytiles;
@@ -2797,12 +3158,20 @@ class tinyBLAS_BF16_PPC {
2797
3158
  for (int64_t job = start; job < end; ++job) {
2798
3159
  int64_t ii = m0 + job / xtiles * RM;
2799
3160
  int64_t jj = n0 + job % xtiles * RN;
2800
- this->kernel<RM, RN>(ii, jj);
3161
+ kernel<RM, RN>(ii, jj);
2801
3162
  }
2802
3163
  }
2803
-
2804
- template class tinyBLAS_Q0_PPC<block_q4_0>;
2805
- template class tinyBLAS_Q0_PPC<block_q8_0>;
3164
+ const TA * const A;
3165
+ const block_q8_0 * const B;
3166
+ float * C;
3167
+ const int64_t k;
3168
+ int64_t kc;
3169
+ const int64_t lda;
3170
+ const int64_t ldb;
3171
+ const int64_t ldc;
3172
+ const int ith;
3173
+ const int nth;
3174
+ };
2806
3175
 
2807
3176
  class tinyBLAS_PPC {
2808
3177
  public:
@@ -3418,16 +3787,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
3418
3787
  return tb.matmul(m, n);
3419
3788
  }
3420
3789
  #elif defined(__MMA__)
3421
- if ((k % 8))
3422
- return false;
3423
- if(Btype == GGML_TYPE_BF16) {
3424
- tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
3425
- (const ggml_bf16_t *)A, lda,
3426
- (const ggml_bf16_t *)B, ldb,
3427
- (float *)C, ldc,
3428
- params->ith, params->nth};
3429
- tb.matmul(m, n);
3430
- return true;
3790
+ if (k % 8) {
3791
+ return false;
3792
+ }
3793
+
3794
+ if (Btype == GGML_TYPE_BF16) {
3795
+ tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
3796
+ (const ggml_bf16_t *)A, lda,
3797
+ (const ggml_bf16_t *)B, ldb,
3798
+ (float *)C, ldc,
3799
+ params->ith, params->nth };
3800
+
3801
+ tb.matmul(m, n);
3802
+ return true;
3431
3803
  }
3432
3804
  #elif defined(__riscv_zvfbfwma)
3433
3805
  #if LMUL == 1
@@ -3516,6 +3888,21 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
3516
3888
  #endif
3517
3889
  return tb.matmul(m, n);
3518
3890
  }
3891
+ #elif defined(__MMA__)
3892
+ if (k % 8) {
3893
+ return false;
3894
+ }
3895
+
3896
+ if (Btype == GGML_TYPE_F16) {
3897
+ tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> tb{ k,
3898
+ (const ggml_fp16_t *)A, lda,
3899
+ (const ggml_fp16_t *)B, ldb,
3900
+ (float *)C, ldc,
3901
+ params->ith, params->nth };
3902
+
3903
+ tb.matmul(m, n);
3904
+ return true;
3905
+ }
3519
3906
  #endif
3520
3907
  return false;
3521
3908
  }