whispercpp 1.3.3 → 1.3.4

This diff shows the changes between publicly available package versions as released to one of the supported registries. It is provided for informational purposes only.
Files changed (586)
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp

@@ -1541,7 +1541,7 @@ class tinyBLAS_BF16_PPC {
         } else if constexpr(RM == 8 && RN == 4) {
             KERNEL_8x4(ii,jj);
         } else {
-            static_assert(false, "RN/RM values not supported");
+            assert(false && "RN/RM values not supported");
         }
     }
 
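Note: this is more than a style fix. Before C++23 (P2593), a static_assert(false, ...) whose condition does not depend on a template parameter is ill-formed even inside the discarded branch of an if constexpr chain, so some compilers reject the template at definition time although the branch is never instantiated; a runtime assert sidesteps that. A minimal sketch of the pattern (hypothetical kernel<> reduction, not code from this diff):

    #include <cassert>

    // With a non-dependent static_assert(false) in the discarded branch,
    // compilers may reject this template even though only kernel<1,1>()
    // is ever instantiated. The runtime assert always compiles.
    template <int RM, int RN>
    void kernel() {
        if constexpr (RM == 1 && RN == 1) {
            // ... supported tile size ...
        } else {
            // static_assert(false, "unsupported");  // ill-formed before C++23
            assert(false && "unsupported");
        }
    }

    int main() { kernel<1, 1>(); }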
@@ -1573,13 +1573,13 @@ class tinyBLAS_BF16_PPC {
     const int nth;
 };
 
-template <typename TA, typename TB, typename TC>
+template <typename TA>
 class tinyBLAS_Q0_PPC {
   public:
     tinyBLAS_Q0_PPC(int64_t k,
                     const TA *A, int64_t lda,
-                    const TB *B, int64_t ldb,
-                    TC *C, int64_t ldc,
+                    const block_q8_0 *B, int64_t ldb,
+                    float *C, int64_t ldc,
                     int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
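Note: the TB and TC template parameters are dropped because this quantized GEMM only ever runs with q8_0-quantized activations for B and a float output for C; only the weight block type TA (e.g. block_q4_0 vs block_q8_0) still varies. A hypothetical instantiation sketch under that assumption (the real call site is not shown in this diff):

    // Only the weight block type is a template parameter now; the B and C
    // types are fixed to block_q8_0 and float inside the class.
    tinyBLAS_Q0_PPC<block_q4_0> tb(k, (const block_q4_0 *)A, lda,
                                   (const block_q8_0 *)B, ldb,
                                   (float *)C, ldc, ith, nth);
    tb.matmul(m, n);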
@@ -1590,8 +1590,7 @@ class tinyBLAS_Q0_PPC {
 
   private:
 
-    template<int RM, int RN>
-    inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
+    inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
        for (int I = 0; I < RM; I++) {
           for (int J = 0; J < RN; J++) {
              *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
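Note: save_res previously needed a separate template instantiation for every (RM, RN) tile shape; with defaulted runtime parameters a single function body serves them all, at the cost of the loop bounds no longer being compile-time constants. A reduced, self-contained illustration of the same trade (hypothetical store_tile, not code from this diff):

    #include <cstdio>

    // Tile shape as defaulted runtime arguments: one body for every
    // (RM, RN) pair, mirroring the save_res signature change above.
    void store_tile(float *C, long ldc, int ii, int jj,
                    const float *res, int RM = 4, int RN = 4) {
        for (int I = 0; I < RM; I++)
            for (int J = 0; J < RN; J++)
                C[ii + (jj + J) * ldc + I] = res[I * RN + J];
    }

    int main() {
        float C[16] = {0}, res[16];
        for (int i = 0; i < 16; i++) res[i] = (float)i;
        store_tile(C, 4, 0, 0, res);   // default 4x4 tile
        printf("%g\n", C[4]);          // row 0, col 1 -> res[1] == 1
    }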
@@ -1611,29 +1610,67 @@ class tinyBLAS_Q0_PPC {
             fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
          }
       }
-
-    template<typename VA, typename VB, int size>
-    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array<int, size>& comparray) {
-        int64_t i, j;
-        TA *aoffset = NULL;
-        VA *vecOffset = NULL;
-        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
-        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
-        VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
-        VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
-        VB t1, t2, t3, t4, t5, t6, t7, t8;
+    /* This function processes quantized data from block_q4_0 elements.
+     * First we extract the two int4 values packed in each int8_t into two vectors of signed int8.
+     * Then we subtract 8 from every element to map the unsigned 4-bit quants onto signed int8.
+     * We also compute the row sum, which is needed to compensate for this conversion. */
+    inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
        const vector signed char lowMask = vec_splats((signed char)0xF);
        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
        const vector signed char v8 = vec_splats((signed char)0x8);
-        aoffset = const_cast<TA*>(a);
-        vecOffset = vec;
+        vector signed int vsum = {0};
+        vector signed int vsum2 = {0};
+        c[0] = vec_and(c[1], lowMask);
+        c[1] = vec_sr(c[1], v4);
+        c[0] = vec_sub(c[0], v8);
+        c[1] = vec_sub(c[1], v8);
+        vsum = vec_sum4s(c[0], vsum);
+        vsum2 = vec_sum4s(c[1], vsum2);
+        vsum = vec_add(vsum, vsum2);
+        *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+    }
+
+    template <typename V1, typename V2>
+    inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
-        vector signed int vsum = {0};
-        vector signed int vsum2 = {0};
+        V2 t1, t2, t3, t4, t5, t6, t7, t8;
+        vector unsigned char xor_vector;
+        uint8_t flip_vec = 0x80;
+        xor_vector = vec_splats(flip_vec);
+        t1 = vec_perm(s1, s2, swiz1);
+        t2 = vec_perm(s1, s2, swiz2);
+        t3 = vec_perm(s3, s4, swiz1);
+        t4 = vec_perm(s3, s4, swiz2);
+        t5 = vec_perm(t1, t3, swiz3);
+        t6 = vec_perm(t1, t3, swiz4);
+        t7 = vec_perm(t2, t4, swiz3);
+        t8 = vec_perm(t2, t4, swiz4);
+        if (flip == true) {
+            t5 = vec_xor(t5, xor_vector);
+            t6 = vec_xor(t6, xor_vector);
+            t7 = vec_xor(t7, xor_vector);
+            t8 = vec_xor(t8, xor_vector);
+        }
+        vec_xst(t5, 0, vecOffset);
+        vec_xst(t6, 0, vecOffset+16);
+        vec_xst(t7, 0, vecOffset+32);
+        vec_xst(t8, 0, vecOffset+48);
+    }
 
+    template<int size>
+    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        int8_t *vecOffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
        j = (rows >> 3);
        if (j > 0) {
            do {
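Note: the comment on process_q4_elements summarizes the q4_0 layout: each byte of qs packs two 4-bit quants (low and high nibble) stored with a +8 bias, and the helper splits them into signed lanes, removes the bias, and accumulates a per-row sum into comparray, which later compensates the bias when the packed values are multiplied against q8_0 activations. A scalar model of the same transformation (assumes the ggml q4_0 layout of 32 quants in 16 bytes; unpack_q4_row is a hypothetical name):

    #include <cstdint>

    // Scalar model of process_q4_elements: split each byte into two
    // nibbles (vec_and / vec_sr), remove the +8 bias (vec_sub), and
    // accumulate the row sum (vec_sum4s) used for bias compensation.
    static int unpack_q4_row(const uint8_t *qs, int8_t out[32]) {
        int rowsum = 0;
        for (int i = 0; i < 16; i++) {
            int8_t lo = (int8_t)(qs[i] & 0x0F) - 8;   // low nibble
            int8_t hi = (int8_t)(qs[i] >> 4) - 8;     // high nibble
            out[i]      = lo;
            out[i + 16] = hi;
            rowsum += lo + hi;
        }
        return rowsum;   // stored into comparray[] by the real code
    }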
@@ -1646,159 +1683,30 @@ class tinyBLAS_Q0_PPC {
                aoffset7 = aoffset6 + lda;
                aoffset8 = aoffset7 + lda;
                aoffset += 8 * lda;
-
                i = (cols >> 2);
                if (i > 0) {
                    do {
-                        c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
-                        c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
-                        c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
-                        c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
-                        c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
-                        c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
-                        c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
-                        c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
-
-                        c1[0] = vec_and(c1[1], lowMask);
-                        c1[1] = vec_sr(c1[1], v4);
-                        c1[0] = vec_sub(c1[0], v8);
-                        c1[1] = vec_sub(c1[1], v8);
-                        vsum = vec_sum4s(c1[0], vsum);
-                        vsum2 = vec_sum4s(c1[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        c2[0] = vec_and(c2[1], lowMask);
-                        c2[1] = vec_sr(c2[1], v4);
-                        c2[0] = vec_sub(c2[0], v8);
-                        c2[1] = vec_sub(c2[1], v8);
-                        vsum = vec_sum4s(c2[0], vsum);
-                        vsum2 = vec_sum4s(c2[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        c3[0] = vec_and(c3[1], lowMask);
-                        c3[1] = vec_sr(c3[1], v4);
-                        c3[0] = vec_sub(c3[0], v8);
-                        c3[1] = vec_sub(c3[1], v8);
-                        vsum = vec_sum4s(c3[0], vsum);
-                        vsum2 = vec_sum4s(c3[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        c4[0] = vec_and(c4[1], lowMask);
-                        c4[1] = vec_sr(c4[1], v4);
-                        c4[0] = vec_sub(c4[0], v8);
-                        c4[1] = vec_sub(c4[1], v8);
-                        vsum = vec_sum4s(c4[0], vsum);
-                        vsum2 = vec_sum4s(c4[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        c5[0] = vec_and(c5[1], lowMask);
-                        c5[1] = vec_sr(c5[1], v4);
-                        c5[0] = vec_sub(c5[0], v8);
-                        c5[1] = vec_sub(c5[1], v8);
-                        vsum = vec_sum4s(c5[0], vsum);
-                        vsum2 = vec_sum4s(c5[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        c6[0] = vec_and(c6[1], lowMask);
-                        c6[1] = vec_sr(c6[1], v4);
-                        c6[0] = vec_sub(c6[0], v8);
-                        c6[1] = vec_sub(c6[1], v8);
-                        vsum = vec_sum4s(c6[0], vsum);
-                        vsum2 = vec_sum4s(c6[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        c7[0] = vec_and(c7[1], lowMask);
-                        c7[1] = vec_sr(c7[1], v4);
-                        c7[0] = vec_sub(c7[0], v8);
-                        c7[1] = vec_sub(c7[1], v8);
-                        vsum = vec_sum4s(c7[0], vsum);
-                        vsum2 = vec_sum4s(c7[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        c8[0] = vec_and(c8[1], lowMask);
-                        c8[1] = vec_sr(c8[1], v4);
-                        c8[0] = vec_sub(c8[0], v8);
-                        c8[1] = vec_sub(c8[1], v8);
-                        vsum = vec_sum4s(c8[0], vsum);
-                        vsum2 = vec_sum4s(c8[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        t1 = vec_perm(c1[0], c2[0], swiz1);
-                        t2 = vec_perm(c1[0], c2[0], swiz2);
-                        t3 = vec_perm(c3[0], c4[0], swiz1);
-                        t4 = vec_perm(c3[0], c4[0], swiz2);
-                        t5 = vec_perm(t1, t3, swiz3);
-                        t6 = vec_perm(t1, t3, swiz4);
-                        t7 = vec_perm(t2, t4, swiz3);
-                        t8 = vec_perm(t2, t4, swiz4);
-                        vec_xst(t5, 0, vecOffset);
-                        vec_xst(t6, 0, vecOffset+16);
-                        vec_xst(t7, 0, vecOffset+32);
-                        vec_xst(t8, 0, vecOffset+48);
-
-                        t1 = vec_perm(c1[1], c2[1], swiz1);
-                        t2 = vec_perm(c1[1], c2[1], swiz2);
-                        t3 = vec_perm(c3[1], c4[1], swiz1);
-                        t4 = vec_perm(c3[1], c4[1], swiz2);
-                        t5 = vec_perm(t1, t3, swiz3);
-                        t6 = vec_perm(t1, t3, swiz4);
-                        t7 = vec_perm(t2, t4, swiz3);
-                        t8 = vec_perm(t2, t4, swiz4);
-                        vec_xst(t5, 0, vecOffset+64);
-                        vec_xst(t6, 0, vecOffset+80);
-                        vec_xst(t7, 0, vecOffset+96);
-                        vec_xst(t8, 0, vecOffset+112);
-
-                        t1 = vec_perm(c5[0], c6[0], swiz1);
-                        t2 = vec_perm(c5[0], c6[0], swiz2);
-                        t3 = vec_perm(c7[0], c8[0], swiz1);
-                        t4 = vec_perm(c7[0], c8[0], swiz2);
-                        t5 = vec_perm(t1, t3, swiz3);
-                        t6 = vec_perm(t1, t3, swiz4);
-                        t7 = vec_perm(t2, t4, swiz3);
-                        t8 = vec_perm(t2, t4, swiz4);
-                        vec_xst(t5, 0, vecOffset+128);
-                        vec_xst(t6, 0, vecOffset+144);
-                        vec_xst(t7, 0, vecOffset+160);
-                        vec_xst(t8, 0, vecOffset+176);
-
-                        t1 = vec_perm(c5[1], c6[1], swiz1);
-                        t2 = vec_perm(c5[1], c6[1], swiz2);
-                        t3 = vec_perm(c7[1], c8[1], swiz1);
-                        t4 = vec_perm(c7[1], c8[1], swiz2);
-                        t5 = vec_perm(t1, t3, swiz3);
-                        t6 = vec_perm(t1, t3, swiz4);
-                        t7 = vec_perm(t2, t4, swiz3);
-                        t8 = vec_perm(t2, t4, swiz4);
-                        vec_xst(t5, 0, vecOffset+192);
-                        vec_xst(t6, 0, vecOffset+208);
-                        vec_xst(t7, 0, vecOffset+224);
-                        vec_xst(t8, 0, vecOffset+240);
-
+                        c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
+                        c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
+                        c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
+                        c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
+                        c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
+                        c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
+                        c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
+                        c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
+
+                        process_q4_elements(c1, &comparray[0]);
+                        process_q4_elements(c2, &comparray[1]);
+                        process_q4_elements(c3, &comparray[2]);
+                        process_q4_elements(c4, &comparray[3]);
+                        process_q4_elements(c5, &comparray[4]);
+                        process_q4_elements(c6, &comparray[5]);
+                        process_q4_elements(c7, &comparray[6]);
+                        process_q4_elements(c8, &comparray[7]);
+                        vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+                        vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+                        vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
+                        vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
                        aoffset1 += lda;
                        aoffset2 += lda;
                        aoffset3 += lda;
@@ -1821,85 +1729,20 @@ class tinyBLAS_Q0_PPC {
  aoffset3 = aoffset2 + lda;
  aoffset4 = aoffset3 + lda;
  aoffset += 4 * lda;
-
  i = (cols >> 2);
  if (i > 0) {
  do {
- c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
- c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
- c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
- c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
-
- c1[0] = vec_and(c1[1], lowMask);
- c1[1] = vec_sr(c1[1], v4);
- c1[0] = vec_sub(c1[0], v8);
- c1[1] = vec_sub(c1[1], v8);
- vsum = vec_sum4s(c1[0], vsum);
- vsum2 = vec_sum4s(c1[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c2[0] = vec_and(c2[1], lowMask);
- c2[1] = vec_sr(c2[1], v4);
- c2[0] = vec_sub(c2[0], v8);
- c2[1] = vec_sub(c2[1], v8);
- vsum = vec_sum4s(c2[0], vsum);
- vsum2 = vec_sum4s(c2[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c3[0] = vec_and(c3[1], lowMask);
- c3[1] = vec_sr(c3[1], v4);
- c3[0] = vec_sub(c3[0], v8);
- c3[1] = vec_sub(c3[1], v8);
- vsum = vec_sum4s(c3[0], vsum);
- vsum2 = vec_sum4s(c3[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c4[0] = vec_and(c4[1], lowMask);
- c4[1] = vec_sr(c4[1], v4);
- c4[0] = vec_sub(c4[0], v8);
- c4[1] = vec_sub(c4[1], v8);
- vsum = vec_sum4s(c4[0], vsum);
- vsum2 = vec_sum4s(c4[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats( 0);
-
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
-
+ c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
+ c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
+ c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
+ c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
+
+ process_q4_elements(c1, &comparray[0]);
+ process_q4_elements(c2, &comparray[1]);
+ process_q4_elements(c3, &comparray[2]);
+ process_q4_elements(c4, &comparray[3]);
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
  aoffset1 += lda;
  aoffset2 += lda;
  aoffset3 += lda;
@@ -1918,80 +1761,17 @@ class tinyBLAS_Q0_PPC {
  if (i > 0) {
  do {
  switch(rows) {
- case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
- case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
- case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+ case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
+ case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
+ case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
  break;
  }
- c1[0] = vec_and(c1[1], lowMask);
- c1[1] = vec_sr(c1[1], v4);
- c1[0] = vec_sub(c1[0], v8);
- c1[1] = vec_sub(c1[1], v8);
- vsum = vec_sum4s(c1[0], vsum);
- vsum2 = vec_sum4s(c1[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c2[0] = vec_and(c2[1], lowMask);
- c2[1] = vec_sr(c2[1], v4);
- c2[0] = vec_sub(c2[0], v8);
- c2[1] = vec_sub(c2[1], v8);
- vsum = vec_sum4s(c2[0], vsum);
- vsum2 = vec_sum4s(c2[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c3[0] = vec_and(c3[1], lowMask);
- c3[1] = vec_sr(c3[1], v4);
- c3[0] = vec_sub(c3[0], v8);
- c3[1] = vec_sub(c3[1], v8);
- vsum = vec_sum4s(c3[0], vsum);
- vsum2 = vec_sum4s(c3[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c4[0] = vec_and(c4[1], lowMask);
- c4[1] = vec_sr(c4[1], v4);
- c4[0] = vec_sub(c4[0], v8);
- c4[1] = vec_sub(c4[1], v8);
- vsum = vec_sum4s(c4[0], vsum);
- vsum2 = vec_sum4s(c4[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
+ process_q4_elements(c1, &comparray[0]);
+ process_q4_elements(c2, &comparray[1]);
+ process_q4_elements(c3, &comparray[2]);
+ process_q4_elements(c4, &comparray[3]);
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
  aoffset1 += lda;
  aoffset2 += lda;
  aoffset3 += lda;
@@ -2001,146 +1781,40 @@ class tinyBLAS_Q0_PPC {
  }
  }
  }
-
  template<typename VA, typename VB>
- void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+ void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
  int64_t i, j;
- TB *aoffset = NULL;
+ block_q8_0 *aoffset = NULL;
  VA *vecOffset = NULL;
- TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
- TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
- VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
- VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
- VB t1, t2, t3, t4, t5, t6, t7, t8;
- vector unsigned char xor_vector;
- uint8_t flip_vec = 0x80;
- xor_vector = vec_splats(flip_vec);
- vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
- vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
- vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
- vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
-
- aoffset = const_cast<TB*>(a);
+ block_q8_0* aoffsets[8];
+ __vector_pair arr[8];
+ VB c[8][2] = {0};
+ VB c1[8] = {0}; VB c2[8] = {0};
+ aoffset = const_cast<block_q8_0*>(a);
  vecOffset = vec;
  j = (rows >> 3);
  if (j > 0) {
  do {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
- aoffset4 = aoffset3 + lda;
- aoffset5 = aoffset4 + lda;
- aoffset6 = aoffset5 + lda;
- aoffset7 = aoffset6 + lda;
- aoffset8 = aoffset7 + lda;
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 8; it++)
+ aoffsets[it] = aoffsets[it-1] + lda;
  aoffset += 8 * lda;

  i = (cols >> 3);
  if (i > 0) {
  do {
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
- C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
- C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
- C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
- C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
-
- __builtin_vsx_disassemble_pair(c1, &C1);
- __builtin_vsx_disassemble_pair(c2, &C2);
- __builtin_vsx_disassemble_pair(c3, &C3);
- __builtin_vsx_disassemble_pair(c4, &C4);
- __builtin_vsx_disassemble_pair(c5, &C5);
- __builtin_vsx_disassemble_pair(c6, &C6);
- __builtin_vsx_disassemble_pair(c7, &C7);
- __builtin_vsx_disassemble_pair(c8, &C8);
-
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
+ for (int it = 0; it < 8; it++) {
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+ c1[it] = c[it][0];
+ c2[it] = c[it][1];
  }
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
-
- t1 = vec_perm(c5[0], c6[0], swiz1);
- t2 = vec_perm(c5[0], c6[0], swiz2);
- t3 = vec_perm(c7[0], c8[0], swiz1);
- t4 = vec_perm(c7[0], c8[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset+128);
- vec_xst(t6, 0, vecOffset+144);
- vec_xst(t7, 0, vecOffset+160);
- vec_xst(t8, 0, vecOffset+176);
-
- t1 = vec_perm(c5[1], c6[1], swiz1);
- t2 = vec_perm(c5[1], c6[1], swiz2);
- t3 = vec_perm(c7[1], c8[1], swiz1);
- t4 = vec_perm(c7[1], c8[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset+192);
- vec_xst(t6, 0, vecOffset+208);
- vec_xst(t7, 0, vecOffset+224);
- vec_xst(t8, 0, vecOffset+240);
-
- aoffset1 += lda;
- aoffset2 += lda;
- aoffset3 += lda;
- aoffset4 += lda;
- aoffset5 += lda;
- aoffset6 += lda;
- aoffset7 += lda;
- aoffset8 += lda;
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+ vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
+ vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
+ for (int it = 0; it < 8; it++)
+ aoffsets[it] += lda;
  vecOffset += 256;
  i--;
  } while(i > 0);
@@ -2150,129 +1824,53 @@ class tinyBLAS_Q0_PPC {
  }

  if (rows & 4) {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
- aoffset4 = aoffset3 + lda;
- aoffset += 4 * lda;
-
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 4; it++ )
+ aoffsets[it] = aoffsets[it-1] + lda;
+ aoffset += 4 * lda;
  i = (cols >> 3);
  if (i > 0) {
  do {
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
-
- __builtin_vsx_disassemble_pair(c1, &C1);
- __builtin_vsx_disassemble_pair(c2, &C2);
- __builtin_vsx_disassemble_pair(c3, &C3);
- __builtin_vsx_disassemble_pair(c4, &C4);
-
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
+ for (int it = 0; it < 4; it++) {
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+ c1[it] = c[it][0];
+ c2[it] = c[it][1];
  }
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+ for (int it = 0; it < 4; it++) {
+ aoffsets[it] += lda;
  }
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
-
- aoffset1 += lda;
- aoffset2 += lda;
- aoffset3 += lda;
- aoffset4 += lda;
  vecOffset += 128;
  i--;
  } while(i > 0);
  }
  }
+
  if (rows & 3) {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 3; it++ )
+ aoffsets[it] = aoffsets[it-1] + lda;
  i = (cols >> 3);
  if (i > 0) {
  do {
  switch(rows) {
- case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
- __builtin_vsx_disassemble_pair(c3, &C3);
- case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
- __builtin_vsx_disassemble_pair(c2, &C2);
- case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
- __builtin_vsx_disassemble_pair(c1, &C1);
+ case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
+ __builtin_vsx_disassemble_pair(c[2], &arr[2]);
+ c1[2] = c[2][0]; c2[2] = c[2][1];
+ case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
+ __builtin_vsx_disassemble_pair(c[1], &arr[1]);
+ c1[1] = c[1][0]; c2[1] = c[1][1];
+ case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
+ __builtin_vsx_disassemble_pair(c[0], &arr[0]);
+ c1[0] = c[0][0]; c2[0] = c[0][1];
  break;
  }
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
-
- aoffset1 += lda;
- aoffset2 += lda;
- aoffset3 += lda;
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+ for (int it = 0; it < 3; it++)
+ aoffsets[it] += lda;
  vecOffset += 128;
  i--;
  } while(i > 0);
@@ -2281,159 +1879,42 @@ class tinyBLAS_Q0_PPC {
  }

  void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t mc, nc, mp, np;
- int m_rem = MIN(m - m0, 8);
- int n_rem = MIN(n - n0, 8);
- // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance
- // issues. After resolving them, below code will be enabled.
- /*if (m_rem >= 16 && n_rem >= 8) {
- mc = 16;
- nc = 8;
- gemm<16,8>(m0, m, n0, n);
- } else if(m_rem >= 8 && n_rem >= 16) {
- mc = 8;
- nc = 16;
- gemm<8,16>(m0, m, n0, n);
- }*/
+ int m_rem = MIN(m - m0, 16);
+ int n_rem = MIN(n - n0, 16);
+
+ int mc = 0, nc = 0;
+
  if (m_rem >= 8 && n_rem >= 8) {
- mc = 8;
- nc = 8;
- gemm<8,8>(m0, m, n0, n);
+ mc = 8;
+ nc = 8;
+ gemm<8, 8>(m0, m, n0, n);
  } else if (m_rem >= 4 && n_rem >= 8) {
  mc = 4;
  nc = 8;
- gemm<4,8>(m0, m, n0, n);
+ gemm<4, 8>(m0, m, n0, n);
  } else if (m_rem >= 8 && n_rem >= 4) {
  mc = 8;
  nc = 4;
- gemm<8,4>(m0, m, n0, n);
+ gemm<8, 4>(m0, m, n0, n);
  } else if (m_rem >= 4 && n_rem >= 4) {
  mc = 4;
  nc = 4;
- gemm_small<4, 4>(m0, m, n0, n);
- } else if ((m_rem < 4) && (n_rem > 4)) {
- nc = 4;
- switch(m_rem) {
- case 1:
- mc = 1;
- gemm_small<1, 4>(m0, m, n0, n);
- break;
- case 2:
- mc = 2;
- gemm_small<2, 4>(m0, m, n0, n);
- break;
- case 3:
- mc = 3;
- gemm_small<3, 4>(m0, m, n0, n);
- break;
- default:
- return;
- }
- } else if ((m_rem > 4) && (n_rem < 4)) {
- mc = 4;
- switch(n_rem) {
- case 1:
- nc = 1;
- gemm_small<4, 1>(m0, m, n0, n);
- break;
- case 2:
- nc = 2;
- gemm_small<4, 2>(m0, m, n0, n);
- break;
- case 3:
- nc = 3;
- gemm_small<4, 3>(m0, m, n0, n);
- break;
- default:
- return;
- }
+ gemm_small(m0, m, n0, n, mc, nc);
  } else {
- switch((m_rem << 4) | n_rem) {
- case 0x43:
- mc = 4;
- nc = 3;
- gemm_small<4, 3>(m0, m, n0, n);
- break;
- case 0x42:
- mc = 4;
- nc = 2;
- gemm_small<4, 2>(m0, m, n0, n);
- break;
- case 0x41:
- mc = 4;
- nc = 1;
- gemm_small<4, 1>(m0, m, n0, n);
- break;
- case 0x34:
- mc = 3;
- nc = 4;
- gemm_small<3, 4>(m0, m, n0, n);
- break;
- case 0x33:
- mc = 3;
- nc = 3;
- gemm_small<3, 3>(m0, m, n0, n);
- break;
- case 0x32:
- mc = 3;
- nc = 2;
- gemm_small<3, 2>(m0, m, n0, n);
- break;
- case 0x31:
- mc = 3;
- nc = 1;
- gemm_small<3, 1>(m0, m, n0, n);
- break;
- case 0x24:
- mc = 2;
- nc = 4;
- gemm_small<2, 4>(m0, m, n0, n);
- break;
- case 0x23:
- mc = 2;
- nc = 3;
- gemm_small<2, 3>(m0, m, n0, n);
- break;
- case 0x22:
- mc = 2;
- nc = 2;
- gemm_small<2, 2>(m0, m, n0, n);
- break;
- case 0x21:
- mc = 2;
- nc = 1;
- gemm_small<2, 1>(m0, m, n0, n);
- break;
- case 0x14:
- mc = 1;
- nc = 4;
- gemm_small<1, 4>(m0, m, n0, n);
- break;
- case 0x13:
- mc = 1;
- nc = 3;
- gemm_small<1, 3>(m0, m, n0, n);
- break;
- case 0x12:
- mc = 1;
- nc = 2;
- gemm_small<1, 2>(m0, m, n0, n);
- break;
- case 0x11:
- mc = 1;
- nc = 1;
- gemm_small<1, 1>(m0, m, n0, n);
- break;
- default:
- return;
- }
+ mc = (m_rem >= 4) ? 4 : m_rem;
+ nc = (n_rem >= 4) ? 4 : n_rem;
+ if (mc == 0 || nc == 0)
+ return;
+ gemm_small(m0, m, n0, n, mc, nc);
  }
- mp = m0 + (m - m0) / mc * mc;
- np = n0 + (n - n0) / nc * nc;
+
+ int64_t mp = m0 + ((m - m0) / mc) * mc;
+ int64_t np = n0 + ((n - n0) / nc) * nc;
  mnpack(mp, m, n0, np);
  mnpack(m0, m, np, n);
  }

+
  void KERNEL_4x8(int64_t ii, int64_t jj) {
  vec_t vec_A[8], vec_B[16] = {0};
  acc_t acc_0, acc_1;
@@ -2445,9 +1926,9 @@ class tinyBLAS_Q0_PPC {
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
  if (std::is_same_v<TA, block_q4_0>) {
- packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
+ packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
  } else {
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
  }
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
  for(int x = 0; x < 8; x++) {
@@ -2475,8 +1956,8 @@ class tinyBLAS_Q0_PPC {
  compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
  compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
  }
- save_res<4, 4>(ii, jj, 0, fin_res);
- save_res<4, 4>(ii, jj+4, 4, fin_res);
+ save_res(ii, jj, 0, fin_res);
+ save_res(ii, jj+4, 4, fin_res);
  }

  void KERNEL_8x4(int64_t ii, int64_t jj) {
@@ -2490,9 +1971,9 @@ class tinyBLAS_Q0_PPC {
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
  if (std::is_same_v<TA, block_q4_0>) {
- packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+ packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
  } else {
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
  }
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
  for(int x = 0; x < 8; x++) {
@@ -2519,8 +2000,8 @@ class tinyBLAS_Q0_PPC {
  compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
  compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
  }
- save_res<4, 4>(ii, jj, 0, fin_res);
- save_res<4, 4>(ii+4, jj, 4, fin_res);
+ save_res(ii, jj, 0, fin_res);
+ save_res(ii+4, jj, 4, fin_res);
  }

  void KERNEL_8x8(int64_t ii, int64_t jj) {
@@ -2536,9 +2017,9 @@ class tinyBLAS_Q0_PPC {
  __builtin_mma_xxsetaccz(&acc_2);
  __builtin_mma_xxsetaccz(&acc_3);
  if (std::is_same_v<TA, block_q4_0>) {
- packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+ packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
  } else {
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
  }
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
  for(int x = 0; x < 8; x++) {
@@ -2570,14 +2051,13 @@ class tinyBLAS_Q0_PPC {
  compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
  compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
  }
- save_res<4, 4>(ii, jj, 0, fin_res);
- save_res<4, 4>(ii+4, jj, 4, fin_res);
- save_res<4, 4>(ii, jj+4, 8, fin_res);
- save_res<4, 4>(ii+4, jj+4, 12, fin_res);
+ save_res(ii, jj, 0, fin_res);
+ save_res(ii+4, jj, 4, fin_res);
+ save_res(ii, jj+4, 8, fin_res);
+ save_res(ii+4, jj+4, 12, fin_res);
  }

- template<int RM, int RN>
- void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
  int64_t ytiles = (m - m0) / RM;
  int64_t xtiles = (n - n0) / RN;
  int64_t tiles = xtiles * ytiles;
@@ -2606,9 +2086,9 @@ class tinyBLAS_Q0_PPC {
  __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
  __builtin_mma_xxsetaccz(&acc_0);
  if (isAblock_q4) {
- packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
+ packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
  } else {
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
  }
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
  for(int x = 0; x < 8; x+=4) {
@@ -2641,7 +2121,7 @@ class tinyBLAS_Q0_PPC {
  fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
  }
  }
- save_res<RM, RN>(ii, jj, 0, fin_res);
+ save_res(ii, jj, 0, fin_res, RM, RN);
  }
  }

@@ -2654,7 +2134,7 @@ class tinyBLAS_Q0_PPC {
  } else if constexpr(RM == 8 && RN == 8) {
  KERNEL_8x8(ii,jj);
  } else {
- static_assert(false, "RN/RM values not supported");
+ assert(false && "RN/RM values not supported");
  }
  }

@@ -2676,10 +2156,8 @@ class tinyBLAS_Q0_PPC {
  }

  const TA *const A;
- const TB *const B;
- TC *C;
- TA *At;
- TB *Bt;
+ const block_q8_0 *const B;
+ float *C;
  const int64_t k;
  const int64_t lda;
  const int64_t ldb;
@@ -2688,266 +2166,183 @@ class tinyBLAS_Q0_PPC {
  const int nth;
  };

- template <typename TA, typename TB, typename TC>
  class tinyBLAS_PPC {
  public:
  tinyBLAS_PPC(int64_t k,
- const TA *A, int64_t lda,
- const TB *B, int64_t ldb,
- TC *C, int64_t ldc,
+ const float * A, int64_t lda,
+ const float * B, int64_t ldb,
+ float * C, int64_t ldc,
  int ith, int nth)
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  }

  void matmul(int64_t m, int64_t n) {
- mnpack(0, m, 0, n);
+ int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
+ if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
+ matmul_tiled(m, n, mc, nc, kc);
+ } else {
+ mnpack(0, m, 0, n);
+ }
  }

  private:

- void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
+ inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+ vec_t vec_C[4];
+ __builtin_mma_disassemble_acc(vec_C, ACC);
+ for (int I = 0; I < 4; I++) {
+ for (int J = 0; J < 4; J++) {
+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
+ }
+ }
+ }
+
+ inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+ vec_t vec_C[4];
+ __builtin_mma_disassemble_acc(vec_C, ACC);
+ for (int I = 0; I < 4; I++) {
+ for (int J = 0; J < 4; J++) {
+ float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
+ *c_ptr += *((float *)&vec_C[I]+J);
+ }
+ }
+ }
+
+ inline void vector_permute_store_4(vector float * src, float * vecOffset) {
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
+ t1 = vec_mergeh(src[0], src[1]);
+ t2 = vec_mergeh(src[2], src[3]);
+ t3 = vec_mergel(src[0], src[1]);
+ t4 = vec_mergel(src[2], src[3]);
+
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t1, t2, 3);
+ t7 = vec_xxpermdi(t3, t4, 0);
+ t8 = vec_xxpermdi(t3, t4, 3);
+
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset + 4);
+ vec_xst(t7, 0, vecOffset + 8);
+ vec_xst(t8, 0, vecOffset + 12);
+ }
+
+ inline void vector_permute_store_8(vector float * src, float * vecOffset) {
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
+ t1 = vec_mergeh(src[0], src[1]);
+ t2 = vec_mergeh(src[2], src[3]);
+ t3 = vec_mergeh(src[4], src[5]);
+ t4 = vec_mergeh(src[6], src[7]);

- template<typename VA>
- void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t3, t4, 0);
+ t7 = vec_xxpermdi(t1, t2, 3);
+ t8 = vec_xxpermdi(t3, t4, 3);
+
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset + 4);
+ vec_xst(t7, 0, vecOffset + 8);
+ vec_xst(t8, 0, vecOffset + 12);
+
+ t1 = vec_mergel(src[0], src[1]);
+ t2 = vec_mergel(src[2], src[3]);
+ t3 = vec_mergel(src[4], src[5]);
+ t4 = vec_mergel(src[6], src[7]);
+
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t3, t4, 0);
+ t7 = vec_xxpermdi(t1, t2, 3);
+ t8 = vec_xxpermdi(t3, t4, 3);
+
+ vec_xst(t5, 0, vecOffset + 16);
+ vec_xst(t6, 0, vecOffset + 20);
+ vec_xst(t7, 0, vecOffset + 24);
+ vec_xst(t8, 0, vecOffset + 28);
+ }
+
+ void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
  int64_t i, j;
- TA *aoffset = NULL, *boffset = NULL;
- TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
- TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
- VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
- VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
- VA t1, t2, t3, t4, t5, t6, t7, t8;
- aoffset = const_cast<TA*>(a);
+ float * aoffsets[8];
+ float * aoffset = NULL, * boffset = NULL;
+ __vector_pair arr[8];
+ vector float c[8][2] = {0};
+ vector float c1[8] = {0};
+ vector float c2[8] = {0};
+ aoffset = const_cast<float *>(a);
  boffset = vec;
  j = (rows >> 3);
  if (j > 0) {
-
  do {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
- aoffset4 = aoffset3 + lda;
- aoffset5 = aoffset4 + lda;
- aoffset6 = aoffset5 + lda;
- aoffset7 = aoffset6 + lda;
- aoffset8 = aoffset7 + lda;
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 8; it++)
+ aoffsets[it] = aoffsets[it-1] + lda;
  aoffset += 8 * lda;
  i = (cols >> 3);
  if (i > 0) {
  do {
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
- C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
- C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
- C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
- C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
- __builtin_vsx_disassemble_pair(c1, &C1);
- __builtin_vsx_disassemble_pair(c2, &C2);
- __builtin_vsx_disassemble_pair(c3, &C3);
- __builtin_vsx_disassemble_pair(c4, &C4);
- __builtin_vsx_disassemble_pair(c5, &C5);
- __builtin_vsx_disassemble_pair(c6, &C6);
- __builtin_vsx_disassemble_pair(c7, &C7);
- __builtin_vsx_disassemble_pair(c8, &C8);
-
- t1 = vec_mergeh(c1[0], c2[0]);
- t2 = vec_mergeh(c3[0], c4[0]);
- t3 = vec_mergeh(c5[0], c6[0]);
- t4 = vec_mergeh(c7[0], c8[0]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset);
- vec_xst(t6, 0, boffset+4);
- vec_xst(t7, 0, boffset+8);
- vec_xst(t8, 0, boffset+12);
-
- t1 = vec_mergel(c1[0], c2[0]);
- t2 = vec_mergel(c3[0], c4[0]);
- t3 = vec_mergel(c5[0], c6[0]);
- t4 = vec_mergel(c7[0], c8[0]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset+16);
- vec_xst(t6, 0, boffset+20);
- vec_xst(t7, 0, boffset+24);
- vec_xst(t8, 0, boffset+28);
-
- t1 = vec_mergeh(c1[1], c2[1]);
- t2 = vec_mergeh(c3[1], c4[1]);
- t3 = vec_mergeh(c5[1], c6[1]);
- t4 = vec_mergeh(c7[1], c8[1]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset+32);
- vec_xst(t6, 0, boffset+36);
- vec_xst(t7, 0, boffset+40);
- vec_xst(t8, 0, boffset+44);
-
- t1 = vec_mergel(c1[1], c2[1]);
- t2 = vec_mergel(c3[1], c4[1]);
- t3 = vec_mergel(c5[1], c6[1]);
- t4 = vec_mergel(c7[1], c8[1]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset+48);
- vec_xst(t6, 0, boffset+52);
- vec_xst(t7, 0, boffset+56);
- vec_xst(t8, 0, boffset+60);
-
- aoffset1 += 8*lda;
- aoffset2 += 8*lda;
- aoffset3 += 8*lda;
- aoffset4 += 8*lda;
+ for (int it = 0; it < 8; it++) {
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+ c1[it] = c[it][0];
+ c2[it] = c[it][1];
+ }
+
+ vector_permute_store_8(c1, boffset);
+ vector_permute_store_8(c2, boffset + 32);
  boffset += 64;
  i--;
+ if (i > 0) {
+ for (int it = 0; it < 8; it++) {
+ aoffsets[it] = aoffsets[it] + 8;
+ }
+ }
  } while(i > 0);
  }
  if (cols & 4) {
- c1[0] = vec_xl(0, aoffset1);
- c2[0] = vec_xl(0, aoffset2);
- c3[0] = vec_xl(0, aoffset3);
- c4[0] = vec_xl(0, aoffset4);
- c5[0] = vec_xl(0, aoffset5);
- c6[0] = vec_xl(0, aoffset6);
- c7[0] = vec_xl(0, aoffset7);
- c8[0] = vec_xl(0, aoffset8);
-
- t1 = vec_mergeh(c1[0], c2[0]);
- t2 = vec_mergeh(c3[0], c4[0]);
- t3 = vec_mergeh(c5[0], c6[0]);
- t4 = vec_mergeh(c7[0], c8[0]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset);
- vec_xst(t6, 0, boffset+4);
- vec_xst(t7, 0, boffset+8);
- vec_xst(t8, 0, boffset+12);
-
- t1 = vec_mergel(c1[0], c2[0]);
- t2 = vec_mergel(c3[0], c4[0]);
- t3 = vec_mergel(c5[0], c6[0]);
- t4 = vec_mergel(c7[0], c8[0]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset+16);
- vec_xst(t6, 0, boffset+20);
- vec_xst(t7, 0, boffset+24);
- vec_xst(t8, 0, boffset+28);
+ for (int it = 0; it < 8 ; it++)
+ c1[it] = vec_xl(0, aoffsets[it]);
+ vector_permute_store_8(c1, boffset);
  }
  j--;
  } while(j > 0);
  }

  if (rows & 4) {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
- aoffset4 = aoffset3 + lda;
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 4; it++)
+ aoffsets[it] = aoffsets[it-1] + lda;
  aoffset += 4 * lda;
  i = (cols >> 3);
  if (i > 0) {
  do {
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
- __builtin_vsx_disassemble_pair(c1, &C1);
- __builtin_vsx_disassemble_pair(c2, &C2);
- __builtin_vsx_disassemble_pair(c3, &C3);
- __builtin_vsx_disassemble_pair(c4, &C4);
-
- t1 = vec_mergeh(c1[0], c2[0]);
- t2 = vec_mergeh(c3[0], c4[0]);
- t3 = vec_mergel(c1[0], c2[0]);
- t4 = vec_mergel(c3[0], c4[0]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t1, t2, 3);
- t7 = vec_xxpermdi(t3, t4, 0);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset);
- vec_xst(t6, 0, boffset+4);
- vec_xst(t7, 0, boffset+8);
- vec_xst(t8, 0, boffset+12);
-
- t1 = vec_mergeh(c1[1], c2[1]);
- t2 = vec_mergeh(c3[1], c4[1]);
- t3 = vec_mergel(c1[1], c2[1]);
- t4 = vec_mergel(c3[1], c4[1]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t1, t2, 3);
- t7 = vec_xxpermdi(t3, t4, 0);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset+16);
- vec_xst(t6, 0, boffset+20);
- vec_xst(t7, 0, boffset+24);
- vec_xst(t8, 0, boffset+28);
-
- aoffset1 += 8*lda;
- aoffset2 += 8*lda;
- aoffset3 += 8*lda;
- aoffset4 += 8*lda;
+ for (int it = 0; it < 4; it++) {
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+ c1[it] = c[it][0];
+ c2[it] = c[it][1];
+ }
+ vector_permute_store_4(c1, boffset);
+ vector_permute_store_4(c2, boffset + 16);
+ for (int it = 0; it < 4; it++)
+ aoffsets[it] += 8 * lda;
  boffset += 32;
  i--;
  } while(i > 0);
  }

  if (cols & 4) {
- c1[0] = vec_xl(0, aoffset1);
- c2[0] = vec_xl(0, aoffset2);
- c3[0] = vec_xl(0, aoffset3);
- c4[0] = vec_xl(0, aoffset4);
-
- t1 = vec_mergeh(c1[0], c2[0]);
- t2 = vec_mergeh(c3[0], c4[0]);
- t3 = vec_xxpermdi(t1, t2, 0);
- t4 = vec_xxpermdi(t1, t2, 3);
- vec_xst(t3, 0, boffset);
- vec_xst(t4, 0, boffset+4);
-
- t1 = vec_mergel(c1[0], c2[0]);
- t2 = vec_mergel(c3[0], c4[0]);
- t3 = vec_xxpermdi(t1, t2, 0);
- t4 = vec_xxpermdi(t1, t2, 3);
- vec_xst(t3, 0, boffset+8);
- vec_xst(t4, 0, boffset+12);
+ for (int it = 0; it < 4; it++)
+ c1[it] = vec_xl(0, aoffsets[it]);
+ vector_permute_store_4(c1, boffset);
  }
  }
  if (rows & 3) {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 3; it++)
+ aoffsets[it] = aoffsets[it-1] + lda;
  if (cols & 4) {
- c1[0] = vec_xl(0, aoffset1);
- c2[0] = vec_xl(0, aoffset2);
- c3[0] = vec_xl(0, aoffset3);
-
- t1 = vec_mergeh(c1[0], c2[0]);
- t2 = vec_mergeh(c3[0], c4[0]);
- t3 = vec_xxpermdi(t1, t2, 0);
- t4 = vec_xxpermdi(t1, t2, 3);
- vec_xst(t3, 0, boffset);
- vec_xst(t4, 0, boffset+4);
-
- t1 = vec_mergel(c1[0], c2[0]);
- t2 = vec_mergel(c3[0], c4[0]);
- t3 = vec_xxpermdi(t1, t2, 0);
- t4 = vec_xxpermdi(t1, t2, 3);
- vec_xst(t3, 0, boffset+8);
- vec_xst(t4, 0, boffset+12);
+ for (int it = 0; it < 3; it++)
+ c1[it] = vec_xl(0, aoffsets[it]);
+ vector_permute_store_4(c1, boffset);
  }
  }
  }
@@ -2956,15 +2351,15 @@ class tinyBLAS_PPC {
  vec_t vec_A[4], vec_B[4], vec_C[4];
  acc_t acc_0;
  __builtin_mma_xxsetaccz(&acc_0);
- for (int l = 0; l < k; l+=4) {
- packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
+ for (int l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
  }
- SAVE_ACC(&acc_0, ii, jj);
+ save_acc(&acc_0, ii, jj);
  }

  void KERNEL_4x8(int64_t ii, int64_t jj) {
@@ -2972,9 +2367,9 @@ class tinyBLAS_PPC {
  acc_t acc_0, acc_1;
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
- for (int64_t l = 0; l < k; l+=4) {
- packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
+ for (int64_t l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -2984,8 +2379,8 @@ class tinyBLAS_PPC {
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
  }
- SAVE_ACC(&acc_0, ii, jj);
- SAVE_ACC(&acc_1, ii, jj+4);
+ save_acc(&acc_0, ii, jj);
+ save_acc(&acc_1, ii, jj + 4);
  }

  void KERNEL_8x4(int64_t ii, int64_t jj) {
@@ -2993,9 +2388,9 @@ class tinyBLAS_PPC {
  acc_t acc_0, acc_1;
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
- for (int64_t l = 0; l < k; l+=4) {
- packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
+ for (int64_t l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -3005,8 +2400,8 @@ class tinyBLAS_PPC {
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
  }
- SAVE_ACC(&acc_0, ii, jj);
- SAVE_ACC(&acc_1, ii+4, jj);
+ save_acc(&acc_0, ii, jj);
+ save_acc(&acc_1, ii + 4, jj);
  }

  void KERNEL_8x8(int64_t ii, int64_t jj) {
@@ -3017,173 +2412,132 @@ class tinyBLAS_PPC {
3017
2412
  __builtin_mma_xxsetaccz(&acc_2);
3018
2413
  __builtin_mma_xxsetaccz(&acc_3);
3019
2414
  for (int l = 0; l < k; l+=8) {
3020
- packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
3021
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
2415
+ packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
2416
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
3022
2417
  for(int x = 0; x < 16; x+=2) {
3023
2418
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
3024
- __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
3025
- __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
3026
- __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
2419
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
2420
+ __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
2421
+ __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
2422
+ }
2423
+ }
2424
+ save_acc(&acc_0, ii, jj);
2425
+ save_acc(&acc_1, ii, jj + 4);
2426
+ save_acc(&acc_2, ii + 4, jj);
2427
+ save_acc(&acc_3, ii + 4, jj + 4);
2428
+ }
2429
+
2430
+ inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
2431
+ for (int x = 0; x < 16; x += 2) {
2432
+ __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
2433
+ __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
2434
+ __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
2435
+ __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
2436
+ __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
2437
+ __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
2438
+ __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
2439
+ __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
2440
+ }
2441
+ }
2442
+
2443
+ void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
2444
+ for (int64_t i = 0; i < mc; i += 16) {
2445
+ int A_base_addr = (mc / 8) * (i / 8) * 16;
2446
+ for (int64_t j = 0; j < nc; j += 8) {
2447
+ int B_base_addr = (nc / 8) * (j / 8) * 16;
2448
+ acc_t acc[8];
2449
+ vec_t A0_block[16]; vec_t A1_block[16];
2450
+ for (int x = 0; x < 8; x++)
2451
+ __builtin_mma_xxsetaccz(&acc[x]);
2452
+ for (int64_t l = 0; l < kc; l += 8) {
2453
+ int A0_block_idx = A_base_addr + (l / 8) * 16;
2454
+ int A1_block_idx = A0_block_idx + (mc / 8) * 16;
2455
+ int B_block_idx = B_base_addr + (l / 8) * 16;
2456
+ vec_t* A0_block = &vec_A[A0_block_idx];
2457
+ vec_t* A1_block = &vec_A[A1_block_idx];
2458
+ vec_t* B_block = &vec_B[B_block_idx];
2459
+ MMA_16x8(A0_block, A1_block, B_block, acc);
2460
+ }
2461
+ if (kk == 0) {
2462
+ save_acc(&acc[0], ii + i, jj + j);
2463
+ save_acc(&acc[1], ii + i, jj + j + 4);
2464
+ save_acc(&acc[2], ii + i + 4, jj + j);
2465
+ save_acc(&acc[3], ii + i + 4, jj + j + 4);
2466
+ save_acc(&acc[4], ii + i + 8, jj + j);
2467
+ save_acc(&acc[5], ii + i + 8, jj + j + 4);
2468
+ save_acc(&acc[6], ii + i + 12, jj + j);
2469
+ save_acc(&acc[7], ii + i + 12, jj + j + 4);
2470
+ } else {
2471
+ add_save_acc(&acc[0], ii + i, jj + j);
2472
+ add_save_acc(&acc[1], ii + i, jj + j + 4);
2473
+ add_save_acc(&acc[2], ii + i + 4, jj + j);
2474
+ add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
2475
+ add_save_acc(&acc[4], ii + i + 8, jj + j);
2476
+ add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
2477
+ add_save_acc(&acc[6], ii + i + 12, jj + j);
2478
+ add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
2479
+ }
2480
+ }
2481
+ }
2482
+ }
2483
+
2484
+ void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) {
2485
+ int64_t ytiles = m / mc;
2486
+ int64_t xtiles = n / nc;
2487
+ int64_t tiles = xtiles * ytiles;
2488
+ int64_t duty = (tiles + nth - 1) / nth;
2489
+ int64_t start = duty * ith;
2490
+ int64_t end = start + duty;
2491
+ if (end > tiles) {
2492
+ end = tiles;
2493
+ }
2494
+ for (int64_t job = start; job < end; ++job) {
2495
+ int64_t ii = (job / xtiles) * mc;
2496
+ int64_t jj = (job % xtiles) * nc;
2497
+ for (int64_t kk = 0; kk < k; kk += kc) {
2498
+ vec_t A_pack[kc * mc / 4];
2499
+ vec_t B_pack[kc * nc / 4];
2500
+ packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
2501
+ packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
2502
+ KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
3027
2503
  }
3028
2504
  }
3029
-        SAVE_ACC(&acc_0, ii, jj);
-        SAVE_ACC(&acc_1, ii, jj+4);
-        SAVE_ACC(&acc_2, ii+4, jj);
-        SAVE_ACC(&acc_3, ii+4, jj+4);
     }
 
     void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t mc, nc, mp, np;
-        int m_rem = MIN(m - m0, 16);
-        int n_rem = MIN(n - n0, 16);
-        if (m_rem >= 16 && n_rem >= 8) {
-            mc = 8;
-            nc = 8;
-            gemm<8,8>(m0, m, n0, n);
-        } else if(m_rem >= 8 && n_rem >= 16) {
-            mc = 8;
-            nc = 8;
-            gemm<8,8>(m0, m, n0, n);
-        } else if (m_rem >= 8 && n_rem >= 8) {
+        int m_rem = MIN(m - m0, 8);
+        int n_rem = MIN(n - n0, 8);
+        int mc = 0, nc = 0;
+        if (m_rem >= 8 && n_rem >= 8) {
             mc = 8;
             nc = 8;
-            gemm<8,8>(m0, m, n0, n);
+            gemm<8, 8>(m0, m, n0, n);
         } else if (m_rem >= 4 && n_rem >= 8) {
             mc = 4;
             nc = 8;
-            gemm<4,8>(m0, m, n0, n);
+            gemm<4, 8>(m0, m, n0, n);
         } else if (m_rem >= 8 && n_rem >= 4) {
             mc = 8;
             nc = 4;
-            gemm<8,4>(m0, m, n0, n);
+            gemm<8, 4>(m0, m, n0, n);
         } else if (m_rem >= 4 && n_rem >= 4) {
             mc = 4;
             nc = 4;
-            gemm<4,4>(m0, m, n0, n);
-        } else if ((m_rem < 4) && (n_rem > 4)) {
-            nc = 4;
-            switch(m_rem) {
-                case 1:
-                    mc = 1;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 2:
-                    mc = 2;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 3:
-                    mc = 3;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                default:
-                    return;
-            }
-        } else if ((m_rem > 4) && (n_rem < 4)) {
-            mc = 4;
-            switch(n_rem) {
-                case 1:
-                    nc = 1;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 2:
-                    nc = 2;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 3:
-                    nc = 3;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                default:
-                    return;
-            }
+            gemm<4, 4>(m0, m, n0, n);
         } else {
-            switch((m_rem << 4) | n_rem) {
-                case 0x43:
-                    mc = 4;
-                    nc = 3;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x42:
-                    mc = 4;
-                    nc = 2;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x41:
-                    mc = 4;
-                    nc = 1;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x34:
-                    mc = 3;
-                    nc = 4;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x33:
-                    mc = 3;
-                    nc = 3;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x32:
-                    mc = 3;
-                    nc = 2;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x31:
-                    mc = 3;
-                    nc = 1;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x24:
-                    mc = 2;
-                    nc = 4;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x23:
-                    mc = 2;
-                    nc = 3;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x22:
-                    mc = 2;
-                    nc = 2;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x21:
-                    mc = 2;
-                    nc = 1;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x14:
-                    mc = 1;
-                    nc = 4;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x13:
-                    mc = 1;
-                    nc = 3;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x12:
-                    mc = 1;
-                    nc = 2;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x11:
-                    mc = 1;
-                    nc = 1;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                default:
-                    return;
-            }
+            mc = (m_rem >= 4) ? 4 : m_rem;
+            nc = (n_rem >= 4) ? 4 : n_rem;
+            if (mc == 0 || nc == 0)
+                return;
+            gemm_small(m0, m, n0, n, mc, nc);
         }
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
+        int64_t mp = m0 + ((m - m0) / mc) * mc;
+        int64_t np = n0 + ((n - n0) / nc) * nc;
         mnpack(mp, m, n0, np);
         mnpack(m0, m, np, n);
     }
 
-    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
         int64_t ytiles = (m - m0) / RM;
         int64_t xtiles = (n - n0) / RN;
         int64_t tiles = xtiles * ytiles;
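Aside: the mnpack() rewrite above collapses three switch ladders into a single clamp-and-recurse tail. Below is a minimal standalone sketch of that control flow; tile() is an illustrative stub standing in for the real gemm<RM, RN>/gemm_small kernels, and the 4x8/8x4 fast paths are omitted for brevity.

#include <algorithm>
#include <cstdio>

// Stand-in for the gemm<RM, RN> / gemm_small tile kernels.
static void tile(long m0, long m, long n0, long n, int mc, int nc) {
    printf("%dx%d tiles over rows [%ld,%ld) x cols [%ld,%ld)\n",
           mc, nc, m0, m, n0, n);
}

static void mnpack(long m0, long m, long n0, long n) {
    int m_rem = (int)std::min(m - m0, 8L);
    int n_rem = (int)std::min(n - n0, 8L);
    int mc = 0, nc = 0;
    if (m_rem >= 8 && n_rem >= 8) {
        mc = 8; nc = 8;                  // fast path: full 8x8 tiles
    } else if (m_rem >= 4 && n_rem >= 4) {
        mc = 4; nc = 4;                  // (4x8 and 8x4 paths omitted here)
    } else {
        mc = (m_rem >= 4) ? 4 : m_rem;   // clamp to the remainder,
        nc = (n_rem >= 4) ? 4 : n_rem;   // exactly as in the new code
        if (mc == 0 || nc == 0)
            return;                      // region is empty
    }
    tile(m0, m, n0, n, mc, nc);
    long mp = m0 + ((m - m0) / mc) * mc; // first row not covered
    long np = n0 + ((n - n0) / nc) * nc; // first column not covered
    mnpack(mp, m, n0, np);               // leftover bottom strip
    mnpack(m0, m, np, n);                // leftover right strip
}

int main() {
    mnpack(0, 13, 0, 10);  // 8x8 block first, then progressively smaller strips
}

Each call covers as much of its region as it can with mc x nc tiles and then recurses on the leftover bottom and right strips, so every remainder combination the old switch statements enumerated case by case is still reached.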
@@ -3198,30 +2552,30 @@ class tinyBLAS_PPC {
             vec_t vec_C[4];
             acc_t acc_0;
             __builtin_mma_xxsetaccz(&acc_0);
-            vec_t vec_A[4] {0}, vec_B[4] = {0};
-            for (int l=0; l<k; l+=4) {
+            vec_t vec_A[4] = {0}, vec_B[4] = {0};
+            for (int l = 0; l < k; l += 4) {
                 /* 'GEMV Forwarding' concept is used in first two conditional loops.
                  * when one of the matrix has a single row/column, the elements are
                  * broadcasted, instead of using packing routine to prepack the
                  * matrix elements.
                  */
                 if (RM == 1) {
-                    TA* a = const_cast<TA*>(A+(ii)*lda+l);
-                    packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
+                    float * a = const_cast<float *>(A + (ii) * lda + l);
+                    packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
                     vec_A[0] = (vec_t)vec_xl(0,a);
-                    vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
-                    vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
-                    vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
+                    vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
+                    vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
+                    vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
                 } else if (RN == 1) {
-                    packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
-                    TB* b = const_cast<TB*>(B+(jj)*ldb+l);
+                    packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
+                    float * b = const_cast<float *>(B + (jj) * ldb + l);
                     vec_B[0] = (vec_t)vec_xl(0,b);
-                    vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
-                    vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
-                    vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
+                    vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
+                    vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
+                    vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
                 } else {
-                    packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
-                    packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
+                    packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
+                    packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
                 }
                 __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
                 __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
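Aside: the 'GEMV Forwarding' comment in the hunk above describes the RM == 1 / RN == 1 paths, where a one-row (or one-column) operand has nothing worth transposing, so its elements are splatted across vector lanes instead of going through packTranspose. A scalar model of the idea, without the MMA intrinsics (illustrative code, not the library's API):

#include <cstdio>

// Scalar model of one RM == 1 tile: a 1 x k row of A times a k x 4
// slice of B, accumulated as k rank-1 updates. The inner loop plays
// the role of vec_splats plus __builtin_mma_xvf32gerpp.
int main() {
    const int k = 4;
    const float A[k]    = {1, 2, 3, 4};       // the single row of A
    const float B[k][4] = {{1, 0, 0, 0},
                           {0, 1, 0, 0},
                           {0, 0, 1, 0},
                           {0, 0, 0, 1}};     // k x 4 tile of B
    float C[4] = {0, 0, 0, 0};
    for (int l = 0; l < k; ++l) {
        float a = A[l];              // "broadcast": one scalar for all lanes
        for (int j = 0; j < 4; ++j) {
            C[j] += a * B[l][j];     // rank-1 accumulate into the C tile
        }
    }
    printf("%g %g %g %g\n", C[0], C[1], C[2], C[3]);  // prints: 1 2 3 4
}

The splat plus rank-1 accumulate is what vec_splats and __builtin_mma_xvf32gerpp perform four lanes at a time in the real kernel.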
@@ -3231,12 +2585,27 @@ class tinyBLAS_PPC {
             __builtin_mma_disassemble_acc(vec_C, &acc_0);
             for (int I = 0; I < RM; I++) {
                 for (int J = 0; J < RN; J++) {
-                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                    *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
                 }
             }
         }
     }
 
+    template<int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr(RM == 4 && RN == 4) {
+            KERNEL_4x4(ii, jj);
+        } else if constexpr(RM == 4 && RN == 8) {
+            KERNEL_4x8(ii, jj);
+        } else if constexpr(RM == 8 && RN == 4) {
+            KERNEL_8x4(ii, jj);
+        } else if constexpr(RM == 8 && RN == 8) {
+            KERNEL_8x8(ii, jj);
+        } else {
+            static_assert(false, "RN/RM values not supported");
+        }
+    }
+
     template <int RM, int RN>
     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int64_t ytiles = (m - m0) / RM;
@@ -3245,29 +2614,18 @@ class tinyBLAS_PPC {
         int64_t duty = (tiles + nth - 1) / nth;
         int64_t start = duty * ith;
         int64_t end = start + duty;
-        if (RM == 4 && RN == 4) {
-            kernel = &tinyBLAS_PPC::KERNEL_4x4;
-        } else if (RM == 4 && RN == 8) {
-            kernel = &tinyBLAS_PPC::KERNEL_4x8;
-        } else if (RM == 8 && RN == 4) {
-            kernel = &tinyBLAS_PPC::KERNEL_8x4;
-        } else if (RM == 8 && RN == 8) {
-            kernel = &tinyBLAS_PPC::KERNEL_8x8;
-        }
         if (end > tiles)
             end = tiles;
         for (int64_t job = start; job < end; ++job) {
             int64_t ii = m0 + job / xtiles * RM;
             int64_t jj = n0 + job % xtiles * RN;
-            (this->*kernel)(ii, jj);
+            kernel<RM, RN>(ii, jj);
         }
     }
 
-    const TA *const A;
-    const TB *const B;
-    TC *C;
-    TA *At;
-    TB *Bt;
+    const float * const A;
+    const float * const B;
+    float * C;
     const int64_t k;
     const int64_t lda;
     const int64_t ldb;
@@ -3366,7 +2724,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
 #elif defined(__MMA__)
     if (k % 8)
         return false;
-    tinyBLAS_PPC<float, float, float> tb{
+    tinyBLAS_PPC tb{
         k, (const float *)A, lda,
         (const float *)B, ldb,
         (float *)C, ldc,
@@ -3493,7 +2851,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
         return false;
     if (m < 8 && m != 4)
         return false;
-    tinyBLAS_Q0_PPC<block_q8_0, block_q8_0, float> tb{
+    tinyBLAS_Q0_PPC<block_q8_0> tb{
         k, (const block_q8_0 *)A, lda,
         (const block_q8_0 *)B, ldb,
         (float *)C, ldc,
@@ -3530,7 +2888,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
         return false;
    if (m < 8 && m != 4)
         return false;
-    tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float> tb{
+    tinyBLAS_Q0_PPC<block_q4_0> tb{
         k, (const block_q4_0 *)A, lda,
         (const block_q8_0 *)B, ldb,
         (float *)C, ldc,
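Aside: besides dropping the now-unused TA/TB/TC template parameters, the diff replaces the run-time member-function-pointer dispatch in gemm() with a kernel<RM, RN> template selected via if constexpr, so the tile kernel is resolved per instantiation at compile time. A self-contained sketch of that pattern (the class and kernels below are stand-ins, not the real tinyBLAS_PPC):

#include <cstdio>

struct Gemm {
    void KERNEL_4x4(long ii, long jj) { printf("4x4 tile at (%ld, %ld)\n", ii, jj); }
    void KERNEL_8x8(long ii, long jj) { printf("8x8 tile at (%ld, %ld)\n", ii, jj); }

    // Compile-time dispatch: each (RM, RN) pair resolves to exactly one
    // kernel, replacing the old pattern of assigning a member-function
    // pointer and calling (this->*kernel)(ii, jj) per tile.
    template <int RM, int RN>
    void kernel(long ii, long jj) {
        if constexpr (RM == 4 && RN == 4) {
            KERNEL_4x4(ii, jj);
        } else if constexpr (RM == 8 && RN == 8) {
            KERNEL_8x8(ii, jj);
        } else {
            static_assert(RM == 0, "RM/RN values not supported");
        }
    }

    template <int RM, int RN>
    void gemm(long m, long n) {
        for (long ii = 0; ii < m; ii += RM)
            for (long jj = 0; jj < n; jj += RN)
                kernel<RM, RN>(ii, jj);  // direct, inlinable call per instantiation
    }
};

int main() {
    Gemm g;
    g.gemm<4, 4>(8, 8);  // prints the four 4x4 tile positions
}

Each gemm<RM, RN> instantiation inlines exactly one kernel, and an unsupported RM/RN pair fails at compile time rather than silently leaving a pointer unset. (The sketch guards the else branch with the dependent condition static_assert(RM == 0, ...); the diff's plain static_assert(false, ...) in a discarded branch is only portable from C++23.)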