whispercpp 1.3.2 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (664)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +59 -27
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/build-xcframework.sh +24 -0
  19. data/ext/sources/examples/CMakeLists.txt +1 -0
  20. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  21. data/ext/sources/examples/addon.node/addon.cpp +154 -35
  22. data/ext/sources/examples/addon.node/index.js +10 -5
  23. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  24. data/ext/sources/examples/bench/bench.cpp +29 -18
  25. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  26. data/ext/sources/examples/cli/cli.cpp +7 -4
  27. data/ext/sources/examples/command/command.cpp +58 -32
  28. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/common-whisper.cpp +14 -7
  31. data/ext/sources/examples/lsp/lsp.cpp +21 -17
  32. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  33. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  34. data/ext/sources/examples/server/server.cpp +193 -35
  35. data/ext/sources/examples/server.py +6 -1
  36. data/ext/sources/examples/stream/stream.cpp +10 -2
  37. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  38. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  39. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
  40. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  41. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  42. data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
  43. data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
  44. data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
  45. data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
  46. data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
  47. data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
  48. data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
  49. data/ext/sources/examples/talk-llama/llama-context.h +68 -32
  50. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  52. data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
  53. data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
  54. data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
  55. data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
  56. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  57. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  58. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
  59. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
  60. data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
  61. data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
  62. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
  63. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
  64. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
  65. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
  66. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  67. data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
  68. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  69. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
  70. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  71. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  72. data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
  73. data/ext/sources/examples/talk-llama/llama-model.h +87 -9
  74. data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
  75. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  76. data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
  77. data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
  78. data/ext/sources/examples/talk-llama/llama.cpp +76 -17
  79. data/ext/sources/examples/talk-llama/llama.h +176 -151
  80. data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
  81. data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
  82. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  83. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  84. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
  85. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  86. data/ext/sources/ggml/CMakeLists.txt +106 -33
  87. data/ext/sources/ggml/cmake/common.cmake +24 -0
  88. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  89. data/ext/sources/ggml/include/ggml-backend.h +18 -2
  90. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  91. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  92. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  93. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  94. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  95. data/ext/sources/ggml/include/ggml.h +365 -21
  96. data/ext/sources/ggml/src/CMakeLists.txt +98 -25
  97. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  98. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  99. data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
  100. data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
  101. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
  102. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  103. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
  104. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  105. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  106. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  107. data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
  108. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
  109. data/ext/sources/ggml/src/ggml-common.h +21 -0
  110. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
  111. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
  112. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  113. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  114. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
  115. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
  116. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
  117. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  118. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  119. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
  120. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  121. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  122. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  123. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  124. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  125. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
  126. data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
  127. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
  128. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
  129. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
  130. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  131. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  132. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  133. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
  134. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
  135. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  136. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
  137. data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
  138. data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
  139. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
  140. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
  141. data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
  142. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
  143. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  144. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  145. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  146. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  147. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
  148. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
  149. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
  150. data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
  151. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  152. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  153. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  154. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  155. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  156. data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
  157. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  158. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  159. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  160. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  161. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  162. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  163. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  164. data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
  165. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
  166. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  167. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  168. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  169. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  170. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
  171. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
  172. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  173. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  174. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  175. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
  176. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  177. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  178. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  179. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
  180. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  181. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  182. data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
  183. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  184. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  185. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  186. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  187. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  188. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  189. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
  190. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
  191. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  192. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  193. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  195. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  196. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  197. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  198. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  199. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  200. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  201. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  202. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  203. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  204. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  205. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  206. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  208. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  210. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  211. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
  212. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  213. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
  214. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  234. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  235. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  236. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  237. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  238. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  239. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  240. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  241. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  242. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  243. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  244. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  245. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  246. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  247. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  248. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  249. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  251. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  252. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  254. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  255. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  259. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  260. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  262. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  270. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  271. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  272. data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
  274. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  277. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  278. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  279. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
  280. data/ext/sources/ggml/src/ggml-impl.h +229 -175
  281. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
  282. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  283. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  284. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  285. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  286. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  287. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  288. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  289. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
  290. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  291. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  292. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  293. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
  294. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  295. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  296. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
  297. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  344. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  345. data/ext/sources/ggml/src/ggml-quants.c +117 -24
  346. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  347. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
  348. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  349. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  350. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
  351. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  352. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  353. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
  354. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
  355. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
  356. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  357. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  358. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
  359. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
  360. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
  361. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
  362. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
  363. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  364. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  365. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  366. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
  367. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
  368. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  369. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  370. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  371. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  372. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
  373. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
  374. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  375. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  401. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  402. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  403. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  404. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
  449. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  450. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  451. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  452. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  453. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  454. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  455. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  456. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  457. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  458. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  459. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  460. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  461. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  462. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  463. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  464. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  465. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  466. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  467. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  468. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  469. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  470. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  471. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  472. data/ext/sources/ggml/src/ggml.c +802 -142
  473. data/ext/sources/ggml/src/ggml.cpp +26 -0
  474. data/ext/sources/ggml/src/gguf.cpp +32 -4
  475. data/ext/sources/include/whisper.h +2 -0
  476. data/ext/sources/src/CMakeLists.txt +2 -0
  477. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  478. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  479. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  480. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  481. data/ext/sources/src/whisper.cpp +241 -215
  482. data/ext/sources/tests/CMakeLists.txt +8 -1
  483. data/ext/sources/tests/test-vad-full.cpp +3 -3
  484. data/ext/sources/tests/test-vad.cpp +2 -2
  485. data/extsources.rb +15 -9
  486. data/lib/whisper/context.rb +15 -0
  487. data/lib/whisper/model/uri.rb +57 -2
  488. data/lib/whisper/segment.rb +58 -0
  489. data/sig/whisper.rbs +75 -38
  490. data/{tests → test}/helper.rb +1 -12
  491. data/{tests → test}/test_model.rb +9 -0
  492. data/test/test_package.rb +51 -0
  493. data/{tests → test}/test_params.rb +8 -0
  494. data/test/test_segment.rb +146 -0
  495. data/{tests → test}/test_whisper.rb +70 -0
  496. data/whispercpp.gemspec +2 -3
  497. metadata +246 -191
  498. data/ext/sources/.dockerignore +0 -3
  499. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  500. data/ext/sources/ci/run.sh +0 -336
  501. data/ext/sources/close-issue.yml +0 -28
  502. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  503. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  504. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  505. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  506. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  507. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  508. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  509. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  510. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  511. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  512. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  513. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  514. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  515. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  516. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  517. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  518. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
  519. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  520. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  521. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  522. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  523. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  524. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  525. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  526. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  527. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  548. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  549. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  550. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  551. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  552. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  553. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  554. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  555. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  556. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  557. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  558. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  559. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  560. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  561. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  562. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  563. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  564. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  565. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  566. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  567. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  568. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  569. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  570. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  571. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  572. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  573. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  574. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  575. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  576. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  577. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  578. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  579. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  580. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  581. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  582. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  583. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  584. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  585. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  586. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  587. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  588. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  589. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  590. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  591. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  592. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  593. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  594. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  595. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  596. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  597. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  598. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  599. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  600. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  601. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  602. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  603. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  604. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  605. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  606. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  607. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  608. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  609. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  610. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  611. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  612. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  613. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  614. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  615. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  616. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  617. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  618. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  619. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  620. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  621. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  622. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  623. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  624. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  625. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  626. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  627. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  628. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  629. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  630. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  631. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  632. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  633. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  634. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  635. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  636. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  637. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  638. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  639. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  640. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  641. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  642. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  643. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  644. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  645. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  646. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  647. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  648. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  649. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  650. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  651. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  652. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  653. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
  654. data/tests/test_package.rb +0 -46
  655. data/tests/test_segment.rb +0 -74
  656. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  657. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  658. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  659. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  660. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  661. /data/{tests → test}/test_callback.rb +0 -0
  662. /data/{tests → test}/test_error.rb +0 -0
  663. /data/{tests → test}/test_vad.rb +0 -0
  664. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -52,8 +52,8 @@
52
52
  #include "ggml-impl.h"
53
53
  #include "ggml-cpu-impl.h"
54
54
  #include "ggml-quants.h"
55
+ #include "simd-mappings.h"
55
56
 
56
- #include <atomic>
57
57
  #include <array>
58
58
  #include <type_traits>
59
59
 
@@ -63,7 +63,7 @@
63
63
  #define NOINLINE __attribute__((__noinline__))
64
64
  #endif
65
65
 
66
- #if defined(__ARM_NEON) || defined(__AVX512F__)
66
+ #if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
67
67
  #define VECTOR_REGISTERS 32
68
68
  #else
69
69
  #define VECTOR_REGISTERS 16
@@ -74,7 +74,7 @@
74
74
  namespace {
75
75
 
76
76
  inline float unhalf(ggml_fp16_t d) {
77
- return GGML_FP16_TO_FP32(d);
77
+ return GGML_CPU_FP16_TO_FP32(d);
78
78
  }
79
79
 
80
80
  ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -110,6 +110,12 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
110
110
  inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
111
111
  #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
112
112
 
113
+ #if defined(__VXE__) || defined(__VXE2__)
114
+ inline float32x4_t add(float32x4_t x, float32x4_t y) { return vec_add(x, y); }
115
+ inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vec_sub(x, y); }
116
+ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
117
+ #endif
118
+
113
119
  #if defined(__MMA__)
114
120
  typedef vector unsigned char vec_t;
115
121
  typedef __vector_quad acc_t;
@@ -163,6 +169,13 @@ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
163
169
  #endif
164
170
  #endif
165
171
 
172
+ #if defined(__VXE__) || defined(__VXE2__)
173
+ template <>
174
+ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
175
+ return vec_madd(a, b, c);
176
+ }
177
+ #endif
178
+
166
179
  ////////////////////////////////////////////////////////////////////////////////////////////////////
167
180
  // VECTORIZED HORIZONTAL SUM
168
181
 
@@ -179,6 +192,13 @@ inline float hsum(float16x8_t x) {
179
192
  }
180
193
  #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
181
194
 
195
+ #if defined(__VXE__) || defined(__VXE2__)
196
+ inline float hsum(float32x4_t x) {
197
+ float32x4_t tmp = x + vec_reve(x);
198
+ return tmp[0] + tmp[1];
199
+ }
200
+ #endif
201
+
182
202
  #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
183
203
  inline float hsum(__m128 x) {
184
204
  #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
@@ -228,6 +248,21 @@ template <> inline float32x4_t load(const ggml_fp16_t *p) {
228
248
  #endif // _MSC_VER
229
249
  #endif // __ARM_NEON
230
250
 
251
+ #if defined(__VXE__) || defined(__VXE2__)
252
+ template <> inline float32x4_t load(const ggml_fp16_t * p) {
253
+ float tmp[4];
254
+
255
+ for (int i = 0; i < 4; i++) {
256
+ tmp[i] = GGML_CPU_FP16_TO_FP32(p[i]);
257
+ }
258
+
259
+ return vec_xl(0, (const float *)(tmp));
260
+ }
261
+ template <> inline float32x4_t load(const float * p) {
262
+ return vec_xl(0, p);
263
+ }
264
+ #endif
265
+
231
266
  #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
232
267
  template <> inline __m128 load(const float *p) {
233
268
  return _mm_loadu_ps(p);
@@ -394,8 +429,6 @@ class tinyBLAS {
394
429
 
395
430
  template <int RM, int RN, int BM>
396
431
  NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
397
- static std::atomic<int64_t> current_chunk;
398
-
399
432
  GGML_ASSERT(m % (RM * BM) == 0);
400
433
  const int64_t ytiles = m / (RM * BM);
401
434
  const int64_t xtiles = (n + RN -1) / RN;
@@ -410,7 +443,7 @@ class tinyBLAS {
410
443
  if (params->ith == 0) {
411
444
  GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
412
445
  // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
413
- std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
446
+ ggml_threadpool_chunk_set(params->threadpool, params->nth);
414
447
  }
415
448
 
416
449
  ggml_barrier(params->threadpool);
@@ -439,8 +472,7 @@ class tinyBLAS {
439
472
  GGML_ASSERT(jj == jj2);
440
473
  }
441
474
 
442
- // next step.
443
- job = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed);
475
+ job = ggml_threadpool_chunk_add(params->threadpool, 1);
444
476
  }
445
477
 
446
478
  ggml_barrier(params->threadpool);
@@ -1509,7 +1541,7 @@ class tinyBLAS_BF16_PPC {
1509
1541
  } else if constexpr(RM == 8 && RN == 4) {
1510
1542
  KERNEL_8x4(ii,jj);
1511
1543
  } else {
1512
- static_assert(false, "RN/RM values not supported");
1544
+ assert(false && "RN/RM values not supported");
1513
1545
  }
1514
1546
  }
1515
1547
 
@@ -1541,13 +1573,13 @@ class tinyBLAS_BF16_PPC {
1541
1573
  const int nth;
1542
1574
  };
1543
1575
 
1544
- template <typename TA, typename TB, typename TC>
1576
+ template <typename TA>
1545
1577
  class tinyBLAS_Q0_PPC {
1546
1578
  public:
1547
1579
  tinyBLAS_Q0_PPC(int64_t k,
1548
1580
  const TA *A, int64_t lda,
1549
- const TB *B, int64_t ldb,
1550
- TC *C, int64_t ldc,
1581
+ const block_q8_0 *B, int64_t ldb,
1582
+ float *C, int64_t ldc,
1551
1583
  int ith, int nth)
1552
1584
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1553
1585
  }
@@ -1558,8 +1590,7 @@ class tinyBLAS_Q0_PPC {
1558
1590
 
1559
1591
  private:
1560
1592
 
1561
- template<int RM, int RN>
1562
- inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
1593
+ inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
1563
1594
  for (int I = 0; I < RM; I++) {
1564
1595
  for (int J = 0; J < RN; J++) {
1565
1596
  *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
@@ -1579,29 +1610,67 @@ class tinyBLAS_Q0_PPC {
1579
1610
  fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
1580
1611
  }
1581
1612
  }
1582
-
1583
- template<typename VA, typename VB, int size>
1584
- void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array<int, size>& comparray) {
1585
- int64_t i, j;
1586
- TA *aoffset = NULL;
1587
- VA *vecOffset = NULL;
1588
- TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1589
- TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1590
- VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
1591
- VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
1592
- VB t1, t2, t3, t4, t5, t6, t7, t8;
1613
+ /* This function processes quantized data from block_q4_0 elements.
1614
+ * First the we try to extract the two int4 values stored in single int8_t into two signed int8.
1615
+ * And then we subtract each of the resultant element with 8, to convert signed int8 to unsigned int8.
1616
+ * Also compute the rowsum which is required to compensate the above conversion. */
1617
+ inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
1593
1618
  const vector signed char lowMask = vec_splats((signed char)0xF);
1594
1619
  const vector unsigned char v4 = vec_splats((unsigned char)0x4);
1595
1620
  const vector signed char v8 = vec_splats((signed char)0x8);
1596
- aoffset = const_cast<TA*>(a);
1597
- vecOffset = vec;
1621
+ vector signed int vsum = {0};
1622
+ vector signed int vsum2 = {0};
1623
+ c[0] = vec_and(c[1], lowMask);
1624
+ c[1] = vec_sr(c[1], v4);
1625
+ c[0] = vec_sub(c[0], v8);
1626
+ c[1] = vec_sub(c[1], v8);
1627
+ vsum = vec_sum4s(c[0], vsum);
1628
+ vsum2 = vec_sum4s(c[1], vsum2);
1629
+ vsum = vec_add(vsum, vsum2);
1630
+ *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1631
+ }
1632
+
1633
+ template <typename V1, typename V2>
1634
+ inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
1598
1635
  vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
1599
1636
  vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
1600
1637
  vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
1601
1638
  vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
1602
- vector signed int vsum = {0};
1603
- vector signed int vsum2 = {0};
1639
+ V2 t1, t2, t3, t4, t5, t6, t7, t8;
1640
+ vector unsigned char xor_vector;
1641
+ uint8_t flip_vec = 0x80;
1642
+ xor_vector = vec_splats(flip_vec);
1643
+ t1 = vec_perm(s1, s2, swiz1);
1644
+ t2 = vec_perm(s1, s2, swiz2);
1645
+ t3 = vec_perm(s3, s4, swiz1);
1646
+ t4 = vec_perm(s3, s4, swiz2);
1647
+ t5 = vec_perm(t1, t3, swiz3);
1648
+ t6 = vec_perm(t1, t3, swiz4);
1649
+ t7 = vec_perm(t2, t4, swiz3);
1650
+ t8 = vec_perm(t2, t4, swiz4);
1651
+ if (flip == true) {
1652
+ t5 = vec_xor(t5, xor_vector);
1653
+ t6 = vec_xor(t6, xor_vector);
1654
+ t7 = vec_xor(t7, xor_vector);
1655
+ t8 = vec_xor(t8, xor_vector);
1656
+ }
1657
+ vec_xst(t5, 0, vecOffset);
1658
+ vec_xst(t6, 0, vecOffset+16);
1659
+ vec_xst(t7, 0, vecOffset+32);
1660
+ vec_xst(t8, 0, vecOffset+48);
1661
+ }
1604
1662
 
1663
+ template<int size>
1664
+ void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
1665
+ int64_t i, j;
1666
+ TA *aoffset = NULL;
1667
+ int8_t *vecOffset = NULL;
1668
+ TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1669
+ TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1670
+ vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
1671
+ vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
1672
+ aoffset = const_cast<TA*>(a);
1673
+ vecOffset = vec;
1605
1674
  j = (rows >> 3);
1606
1675
  if (j > 0) {
1607
1676
  do {
@@ -1614,159 +1683,30 @@ class tinyBLAS_Q0_PPC {
1614
1683
  aoffset7 = aoffset6 + lda;
1615
1684
  aoffset8 = aoffset7 + lda;
1616
1685
  aoffset += 8 * lda;
1617
-
1618
1686
  i = (cols >> 2);
1619
1687
  if (i > 0) {
1620
1688
  do {
1621
- c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
1622
- c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
1623
- c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
1624
- c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
1625
- c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
1626
- c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
1627
- c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
1628
- c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
1629
-
1630
- c1[0] = vec_and(c1[1], lowMask);
1631
- c1[1] = vec_sr(c1[1], v4);
1632
- c1[0] = vec_sub(c1[0], v8);
1633
- c1[1] = vec_sub(c1[1], v8);
1634
- vsum = vec_sum4s(c1[0], vsum);
1635
- vsum2 = vec_sum4s(c1[1], vsum2);
1636
- vsum = vec_add(vsum, vsum2);
1637
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1638
- vsum = vec_splats(0);
1639
- vsum2 = vec_splats(0);
1640
-
1641
- c2[0] = vec_and(c2[1], lowMask);
1642
- c2[1] = vec_sr(c2[1], v4);
1643
- c2[0] = vec_sub(c2[0], v8);
1644
- c2[1] = vec_sub(c2[1], v8);
1645
- vsum = vec_sum4s(c2[0], vsum);
1646
- vsum2 = vec_sum4s(c2[1], vsum2);
1647
- vsum = vec_add(vsum, vsum2);
1648
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1649
- vsum = vec_splats(0);
1650
- vsum2 = vec_splats(0);
1651
-
1652
- c3[0] = vec_and(c3[1], lowMask);
1653
- c3[1] = vec_sr(c3[1], v4);
1654
- c3[0] = vec_sub(c3[0], v8);
1655
- c3[1] = vec_sub(c3[1], v8);
1656
- vsum = vec_sum4s(c3[0], vsum);
1657
- vsum2 = vec_sum4s(c3[1], vsum2);
1658
- vsum = vec_add(vsum, vsum2);
1659
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1660
- vsum = vec_splats(0);
1661
- vsum2 = vec_splats(0);
1662
-
1663
- c4[0] = vec_and(c4[1], lowMask);
1664
- c4[1] = vec_sr(c4[1], v4);
1665
- c4[0] = vec_sub(c4[0], v8);
1666
- c4[1] = vec_sub(c4[1], v8);
1667
- vsum = vec_sum4s(c4[0], vsum);
1668
- vsum2 = vec_sum4s(c4[1], vsum2);
1669
- vsum = vec_add(vsum, vsum2);
1670
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1671
- vsum = vec_splats(0);
1672
- vsum2 = vec_splats(0);
1673
-
1674
- c5[0] = vec_and(c5[1], lowMask);
1675
- c5[1] = vec_sr(c5[1], v4);
1676
- c5[0] = vec_sub(c5[0], v8);
1677
- c5[1] = vec_sub(c5[1], v8);
1678
- vsum = vec_sum4s(c5[0], vsum);
1679
- vsum2 = vec_sum4s(c5[1], vsum2);
1680
- vsum = vec_add(vsum, vsum2);
1681
- comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1682
- vsum = vec_splats(0);
1683
- vsum2 = vec_splats(0);
1684
-
1685
- c6[0] = vec_and(c6[1], lowMask);
1686
- c6[1] = vec_sr(c6[1], v4);
1687
- c6[0] = vec_sub(c6[0], v8);
1688
- c6[1] = vec_sub(c6[1], v8);
1689
- vsum = vec_sum4s(c6[0], vsum);
1690
- vsum2 = vec_sum4s(c6[1], vsum2);
1691
- vsum = vec_add(vsum, vsum2);
1692
- comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1693
- vsum = vec_splats(0);
1694
- vsum2 = vec_splats(0);
1695
-
1696
- c7[0] = vec_and(c7[1], lowMask);
1697
- c7[1] = vec_sr(c7[1], v4);
1698
- c7[0] = vec_sub(c7[0], v8);
1699
- c7[1] = vec_sub(c7[1], v8);
1700
- vsum = vec_sum4s(c7[0], vsum);
1701
- vsum2 = vec_sum4s(c7[1], vsum2);
1702
- vsum = vec_add(vsum, vsum2);
1703
- comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1704
- vsum = vec_splats(0);
1705
- vsum2 = vec_splats(0);
1706
-
1707
- c8[0] = vec_and(c8[1], lowMask);
1708
- c8[1] = vec_sr(c8[1], v4);
1709
- c8[0] = vec_sub(c8[0], v8);
1710
- c8[1] = vec_sub(c8[1], v8);
1711
- vsum = vec_sum4s(c8[0], vsum);
1712
- vsum2 = vec_sum4s(c8[1], vsum2);
1713
- vsum = vec_add(vsum, vsum2);
1714
- comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1715
- vsum = vec_splats(0);
1716
- vsum2 = vec_splats(0);
1717
-
1718
- t1 = vec_perm(c1[0], c2[0], swiz1);
1719
- t2 = vec_perm(c1[0], c2[0], swiz2);
1720
- t3 = vec_perm(c3[0], c4[0], swiz1);
1721
- t4 = vec_perm(c3[0], c4[0], swiz2);
1722
- t5 = vec_perm(t1, t3, swiz3);
1723
- t6 = vec_perm(t1, t3, swiz4);
1724
- t7 = vec_perm(t2, t4, swiz3);
1725
- t8 = vec_perm(t2, t4, swiz4);
1726
- vec_xst(t5, 0, vecOffset);
1727
- vec_xst(t6, 0, vecOffset+16);
1728
- vec_xst(t7, 0, vecOffset+32);
1729
- vec_xst(t8, 0, vecOffset+48);
1730
-
1731
- t1 = vec_perm(c1[1], c2[1], swiz1);
1732
- t2 = vec_perm(c1[1], c2[1], swiz2);
1733
- t3 = vec_perm(c3[1], c4[1], swiz1);
1734
- t4 = vec_perm(c3[1], c4[1], swiz2);
1735
- t5 = vec_perm(t1, t3, swiz3);
1736
- t6 = vec_perm(t1, t3, swiz4);
1737
- t7 = vec_perm(t2, t4, swiz3);
1738
- t8 = vec_perm(t2, t4, swiz4);
1739
- vec_xst(t5, 0, vecOffset+64);
1740
- vec_xst(t6, 0, vecOffset+80);
1741
- vec_xst(t7, 0, vecOffset+96);
1742
- vec_xst(t8, 0, vecOffset+112);
1743
-
1744
- t1 = vec_perm(c5[0], c6[0], swiz1);
1745
- t2 = vec_perm(c5[0], c6[0], swiz2);
1746
- t3 = vec_perm(c7[0], c8[0], swiz1);
1747
- t4 = vec_perm(c7[0], c8[0], swiz2);
1748
- t5 = vec_perm(t1, t3, swiz3);
1749
- t6 = vec_perm(t1, t3, swiz4);
1750
- t7 = vec_perm(t2, t4, swiz3);
1751
- t8 = vec_perm(t2, t4, swiz4);
1752
- vec_xst(t5, 0, vecOffset+128);
1753
- vec_xst(t6, 0, vecOffset+144);
1754
- vec_xst(t7, 0, vecOffset+160);
1755
- vec_xst(t8, 0, vecOffset+176);
1756
-
1757
- t1 = vec_perm(c5[1], c6[1], swiz1);
1758
- t2 = vec_perm(c5[1], c6[1], swiz2);
1759
- t3 = vec_perm(c7[1], c8[1], swiz1);
1760
- t4 = vec_perm(c7[1], c8[1], swiz2);
1761
- t5 = vec_perm(t1, t3, swiz3);
1762
- t6 = vec_perm(t1, t3, swiz4);
1763
- t7 = vec_perm(t2, t4, swiz3);
1764
- t8 = vec_perm(t2, t4, swiz4);
1765
- vec_xst(t5, 0, vecOffset+192);
1766
- vec_xst(t6, 0, vecOffset+208);
1767
- vec_xst(t7, 0, vecOffset+224);
1768
- vec_xst(t8, 0, vecOffset+240);
1769
-
1689
+ c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
1690
+ c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
1691
+ c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
1692
+ c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
1693
+ c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
1694
+ c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
1695
+ c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
1696
+ c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
1697
+
1698
+ process_q4_elements(c1, &comparray[0]);
1699
+ process_q4_elements(c2, &comparray[1]);
1700
+ process_q4_elements(c3, &comparray[2]);
1701
+ process_q4_elements(c4, &comparray[3]);
1702
+ process_q4_elements(c5, &comparray[4]);
1703
+ process_q4_elements(c6, &comparray[5]);
1704
+ process_q4_elements(c7, &comparray[6]);
1705
+ process_q4_elements(c8, &comparray[7]);
1706
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
1707
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
1708
+ vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
1709
+ vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
1770
1710
  aoffset1 += lda;
1771
1711
  aoffset2 += lda;
1772
1712
  aoffset3 += lda;
@@ -1789,85 +1729,20 @@ class tinyBLAS_Q0_PPC {
1789
1729
  aoffset3 = aoffset2 + lda;
1790
1730
  aoffset4 = aoffset3 + lda;
1791
1731
  aoffset += 4 * lda;
1792
-
1793
1732
  i = (cols >> 2);
1794
1733
  if (i > 0) {
1795
1734
  do {
1796
- c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
1797
- c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
1798
- c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
1799
- c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
1800
-
1801
- c1[0] = vec_and(c1[1], lowMask);
1802
- c1[1] = vec_sr(c1[1], v4);
1803
- c1[0] = vec_sub(c1[0], v8);
1804
- c1[1] = vec_sub(c1[1], v8);
1805
- vsum = vec_sum4s(c1[0], vsum);
1806
- vsum2 = vec_sum4s(c1[1], vsum2);
1807
- vsum = vec_add(vsum, vsum2);
1808
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1809
- vsum = vec_splats(0);
1810
- vsum2 = vec_splats(0);
1811
-
1812
- c2[0] = vec_and(c2[1], lowMask);
1813
- c2[1] = vec_sr(c2[1], v4);
1814
- c2[0] = vec_sub(c2[0], v8);
1815
- c2[1] = vec_sub(c2[1], v8);
1816
- vsum = vec_sum4s(c2[0], vsum);
1817
- vsum2 = vec_sum4s(c2[1], vsum2);
1818
- vsum = vec_add(vsum, vsum2);
1819
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1820
- vsum = vec_splats(0);
1821
- vsum2 = vec_splats(0);
1822
-
1823
- c3[0] = vec_and(c3[1], lowMask);
1824
- c3[1] = vec_sr(c3[1], v4);
1825
- c3[0] = vec_sub(c3[0], v8);
1826
- c3[1] = vec_sub(c3[1], v8);
1827
- vsum = vec_sum4s(c3[0], vsum);
1828
- vsum2 = vec_sum4s(c3[1], vsum2);
1829
- vsum = vec_add(vsum, vsum2);
1830
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1831
- vsum = vec_splats(0);
1832
- vsum2 = vec_splats(0);
1833
-
1834
- c4[0] = vec_and(c4[1], lowMask);
1835
- c4[1] = vec_sr(c4[1], v4);
1836
- c4[0] = vec_sub(c4[0], v8);
1837
- c4[1] = vec_sub(c4[1], v8);
1838
- vsum = vec_sum4s(c4[0], vsum);
1839
- vsum2 = vec_sum4s(c4[1], vsum2);
1840
- vsum = vec_add(vsum, vsum2);
1841
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1842
- vsum = vec_splats(0);
1843
- vsum2 = vec_splats( 0);
1844
-
1845
- t1 = vec_perm(c1[0], c2[0], swiz1);
1846
- t2 = vec_perm(c1[0], c2[0], swiz2);
1847
- t3 = vec_perm(c3[0], c4[0], swiz1);
1848
- t4 = vec_perm(c3[0], c4[0], swiz2);
1849
- t5 = vec_perm(t1, t3, swiz3);
1850
- t6 = vec_perm(t1, t3, swiz4);
1851
- t7 = vec_perm(t2, t4, swiz3);
1852
- t8 = vec_perm(t2, t4, swiz4);
1853
- vec_xst(t5, 0, vecOffset);
1854
- vec_xst(t6, 0, vecOffset+16);
1855
- vec_xst(t7, 0, vecOffset+32);
1856
- vec_xst(t8, 0, vecOffset+48);
1857
-
1858
- t1 = vec_perm(c1[1], c2[1], swiz1);
1859
- t2 = vec_perm(c1[1], c2[1], swiz2);
1860
- t3 = vec_perm(c3[1], c4[1], swiz1);
1861
- t4 = vec_perm(c3[1], c4[1], swiz2);
1862
- t5 = vec_perm(t1, t3, swiz3);
1863
- t6 = vec_perm(t1, t3, swiz4);
1864
- t7 = vec_perm(t2, t4, swiz3);
1865
- t8 = vec_perm(t2, t4, swiz4);
1866
- vec_xst(t5, 0, vecOffset+64);
1867
- vec_xst(t6, 0, vecOffset+80);
1868
- vec_xst(t7, 0, vecOffset+96);
1869
- vec_xst(t8, 0, vecOffset+112);
1870
-
1735
+ c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
1736
+ c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
1737
+ c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
1738
+ c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
1739
+
1740
+ process_q4_elements(c1, &comparray[0]);
1741
+ process_q4_elements(c2, &comparray[1]);
1742
+ process_q4_elements(c3, &comparray[2]);
1743
+ process_q4_elements(c4, &comparray[3]);
1744
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
1745
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
1871
1746
  aoffset1 += lda;
1872
1747
  aoffset2 += lda;
1873
1748
  aoffset3 += lda;
@@ -1886,80 +1761,17 @@ class tinyBLAS_Q0_PPC {
1886
1761
  if (i > 0) {
1887
1762
  do {
1888
1763
  switch(rows) {
1889
- case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
1890
- case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
1891
- case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
1764
+ case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
1765
+ case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
1766
+ case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
1892
1767
  break;
1893
1768
  }
1894
- c1[0] = vec_and(c1[1], lowMask);
1895
- c1[1] = vec_sr(c1[1], v4);
1896
- c1[0] = vec_sub(c1[0], v8);
1897
- c1[1] = vec_sub(c1[1], v8);
1898
- vsum = vec_sum4s(c1[0], vsum);
1899
- vsum2 = vec_sum4s(c1[1], vsum2);
1900
- vsum = vec_add(vsum, vsum2);
1901
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1902
- vsum = vec_splats(0);
1903
- vsum2 = vec_splats(0);
1904
-
1905
- c2[0] = vec_and(c2[1], lowMask);
1906
- c2[1] = vec_sr(c2[1], v4);
1907
- c2[0] = vec_sub(c2[0], v8);
1908
- c2[1] = vec_sub(c2[1], v8);
1909
- vsum = vec_sum4s(c2[0], vsum);
1910
- vsum2 = vec_sum4s(c2[1], vsum2);
1911
- vsum = vec_add(vsum, vsum2);
1912
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1913
- vsum = vec_splats(0);
1914
- vsum2 = vec_splats(0);
1915
-
1916
- c3[0] = vec_and(c3[1], lowMask);
1917
- c3[1] = vec_sr(c3[1], v4);
1918
- c3[0] = vec_sub(c3[0], v8);
1919
- c3[1] = vec_sub(c3[1], v8);
1920
- vsum = vec_sum4s(c3[0], vsum);
1921
- vsum2 = vec_sum4s(c3[1], vsum2);
1922
- vsum = vec_add(vsum, vsum2);
1923
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1924
- vsum = vec_splats(0);
1925
- vsum2 = vec_splats(0);
1926
-
1927
- c4[0] = vec_and(c4[1], lowMask);
1928
- c4[1] = vec_sr(c4[1], v4);
1929
- c4[0] = vec_sub(c4[0], v8);
1930
- c4[1] = vec_sub(c4[1], v8);
1931
- vsum = vec_sum4s(c4[0], vsum);
1932
- vsum2 = vec_sum4s(c4[1], vsum2);
1933
- vsum = vec_add(vsum, vsum2);
1934
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1935
- vsum = vec_splats(0);
1936
- vsum2 = vec_splats(0);
1937
-
1938
- t1 = vec_perm(c1[0], c2[0], swiz1);
1939
- t2 = vec_perm(c1[0], c2[0], swiz2);
1940
- t3 = vec_perm(c3[0], c4[0], swiz1);
1941
- t4 = vec_perm(c3[0], c4[0], swiz2);
1942
- t5 = vec_perm(t1, t3, swiz3);
1943
- t6 = vec_perm(t1, t3, swiz4);
1944
- t7 = vec_perm(t2, t4, swiz3);
1945
- t8 = vec_perm(t2, t4, swiz4);
1946
- vec_xst(t5, 0, vecOffset);
1947
- vec_xst(t6, 0, vecOffset+16);
1948
- vec_xst(t7, 0, vecOffset+32);
1949
- vec_xst(t8, 0, vecOffset+48);
1950
-
1951
- t1 = vec_perm(c1[1], c2[1], swiz1);
1952
- t2 = vec_perm(c1[1], c2[1], swiz2);
1953
- t3 = vec_perm(c3[1], c4[1], swiz1);
1954
- t4 = vec_perm(c3[1], c4[1], swiz2);
1955
- t5 = vec_perm(t1, t3, swiz3);
1956
- t6 = vec_perm(t1, t3, swiz4);
1957
- t7 = vec_perm(t2, t4, swiz3);
1958
- t8 = vec_perm(t2, t4, swiz4);
1959
- vec_xst(t5, 0, vecOffset+64);
1960
- vec_xst(t6, 0, vecOffset+80);
1961
- vec_xst(t7, 0, vecOffset+96);
1962
- vec_xst(t8, 0, vecOffset+112);
1769
+ process_q4_elements(c1, &comparray[0]);
1770
+ process_q4_elements(c2, &comparray[1]);
1771
+ process_q4_elements(c3, &comparray[2]);
1772
+ process_q4_elements(c4, &comparray[3]);
1773
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
1774
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
1963
1775
  aoffset1 += lda;
1964
1776
  aoffset2 += lda;
1965
1777
  aoffset3 += lda;
@@ -1969,146 +1781,40 @@ class tinyBLAS_Q0_PPC {
1969
1781
  }
1970
1782
  }
1971
1783
  }
1972
-
1973
1784
  template<typename VA, typename VB>
1974
- void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
1785
+ void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
1975
1786
  int64_t i, j;
1976
- TB *aoffset = NULL;
1787
+ block_q8_0 *aoffset = NULL;
1977
1788
  VA *vecOffset = NULL;
1978
- TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1979
- TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1980
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
1981
- VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
1982
- VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
1983
- VB t1, t2, t3, t4, t5, t6, t7, t8;
1984
- vector unsigned char xor_vector;
1985
- uint8_t flip_vec = 0x80;
1986
- xor_vector = vec_splats(flip_vec);
1987
- vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
1988
- vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
1989
- vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
1990
- vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
1991
-
1992
- aoffset = const_cast<TB*>(a);
1789
+ block_q8_0* aoffsets[8];
1790
+ __vector_pair arr[8];
1791
+ VB c[8][2] = {0};
1792
+ VB c1[8] = {0}; VB c2[8] = {0};
1793
+ aoffset = const_cast<block_q8_0*>(a);
1993
1794
  vecOffset = vec;
1994
1795
  j = (rows >> 3);
1995
1796
  if (j > 0) {
1996
1797
  do {
1997
- aoffset1 = aoffset;
1998
- aoffset2 = aoffset1 + lda;
1999
- aoffset3 = aoffset2 + lda;
2000
- aoffset4 = aoffset3 + lda;
2001
- aoffset5 = aoffset4 + lda;
2002
- aoffset6 = aoffset5 + lda;
2003
- aoffset7 = aoffset6 + lda;
2004
- aoffset8 = aoffset7 + lda;
1798
+ aoffsets[0] = aoffset;
1799
+ for (int it = 1; it < 8; it++)
1800
+ aoffsets[it] = aoffsets[it-1] + lda;
2005
1801
  aoffset += 8 * lda;
2006
1802
 
2007
1803
  i = (cols >> 3);
2008
1804
  if (i > 0) {
2009
1805
  do {
2010
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
2011
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
2012
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
2013
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
2014
- C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
2015
- C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
2016
- C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
2017
- C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
2018
-
2019
- __builtin_vsx_disassemble_pair(c1, &C1);
2020
- __builtin_vsx_disassemble_pair(c2, &C2);
2021
- __builtin_vsx_disassemble_pair(c3, &C3);
2022
- __builtin_vsx_disassemble_pair(c4, &C4);
2023
- __builtin_vsx_disassemble_pair(c5, &C5);
2024
- __builtin_vsx_disassemble_pair(c6, &C6);
2025
- __builtin_vsx_disassemble_pair(c7, &C7);
2026
- __builtin_vsx_disassemble_pair(c8, &C8);
2027
-
2028
- t1 = vec_perm(c1[0], c2[0], swiz1);
2029
- t2 = vec_perm(c1[0], c2[0], swiz2);
2030
- t3 = vec_perm(c3[0], c4[0], swiz1);
2031
- t4 = vec_perm(c3[0], c4[0], swiz2);
2032
- t5 = vec_perm(t1, t3, swiz3);
2033
- t6 = vec_perm(t1, t3, swiz4);
2034
- t7 = vec_perm(t2, t4, swiz3);
2035
- t8 = vec_perm(t2, t4, swiz4);
2036
- if (flip == true) {
2037
- t5 = vec_xor(t5, xor_vector);
2038
- t6 = vec_xor(t6, xor_vector);
2039
- t7 = vec_xor(t7, xor_vector);
2040
- t8 = vec_xor(t8, xor_vector);
2041
- }
2042
- vec_xst(t5, 0, vecOffset);
2043
- vec_xst(t6, 0, vecOffset+16);
2044
- vec_xst(t7, 0, vecOffset+32);
2045
- vec_xst(t8, 0, vecOffset+48);
2046
-
2047
- t1 = vec_perm(c1[1], c2[1], swiz1);
2048
- t2 = vec_perm(c1[1], c2[1], swiz2);
2049
- t3 = vec_perm(c3[1], c4[1], swiz1);
2050
- t4 = vec_perm(c3[1], c4[1], swiz2);
2051
- t5 = vec_perm(t1, t3, swiz3);
2052
- t6 = vec_perm(t1, t3, swiz4);
2053
- t7 = vec_perm(t2, t4, swiz3);
2054
- t8 = vec_perm(t2, t4, swiz4);
2055
- if (flip == true) {
2056
- t5 = vec_xor(t5, xor_vector);
2057
- t6 = vec_xor(t6, xor_vector);
2058
- t7 = vec_xor(t7, xor_vector);
2059
- t8 = vec_xor(t8, xor_vector);
1806
+ for (int it = 0; it < 8; it++) {
1807
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
1808
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
1809
+ c1[it] = c[it][0];
1810
+ c2[it] = c[it][1];
2060
1811
  }
2061
- vec_xst(t5, 0, vecOffset+64);
2062
- vec_xst(t6, 0, vecOffset+80);
2063
- vec_xst(t7, 0, vecOffset+96);
2064
- vec_xst(t8, 0, vecOffset+112);
2065
-
2066
- t1 = vec_perm(c5[0], c6[0], swiz1);
2067
- t2 = vec_perm(c5[0], c6[0], swiz2);
2068
- t3 = vec_perm(c7[0], c8[0], swiz1);
2069
- t4 = vec_perm(c7[0], c8[0], swiz2);
2070
- t5 = vec_perm(t1, t3, swiz3);
2071
- t6 = vec_perm(t1, t3, swiz4);
2072
- t7 = vec_perm(t2, t4, swiz3);
2073
- t8 = vec_perm(t2, t4, swiz4);
2074
- if (flip == true) {
2075
- t5 = vec_xor(t5, xor_vector);
2076
- t6 = vec_xor(t6, xor_vector);
2077
- t7 = vec_xor(t7, xor_vector);
2078
- t8 = vec_xor(t8, xor_vector);
2079
- }
2080
- vec_xst(t5, 0, vecOffset+128);
2081
- vec_xst(t6, 0, vecOffset+144);
2082
- vec_xst(t7, 0, vecOffset+160);
2083
- vec_xst(t8, 0, vecOffset+176);
2084
-
2085
- t1 = vec_perm(c5[1], c6[1], swiz1);
2086
- t2 = vec_perm(c5[1], c6[1], swiz2);
2087
- t3 = vec_perm(c7[1], c8[1], swiz1);
2088
- t4 = vec_perm(c7[1], c8[1], swiz2);
2089
- t5 = vec_perm(t1, t3, swiz3);
2090
- t6 = vec_perm(t1, t3, swiz4);
2091
- t7 = vec_perm(t2, t4, swiz3);
2092
- t8 = vec_perm(t2, t4, swiz4);
2093
- if (flip == true) {
2094
- t5 = vec_xor(t5, xor_vector);
2095
- t6 = vec_xor(t6, xor_vector);
2096
- t7 = vec_xor(t7, xor_vector);
2097
- t8 = vec_xor(t8, xor_vector);
2098
- }
2099
- vec_xst(t5, 0, vecOffset+192);
2100
- vec_xst(t6, 0, vecOffset+208);
2101
- vec_xst(t7, 0, vecOffset+224);
2102
- vec_xst(t8, 0, vecOffset+240);
2103
-
2104
- aoffset1 += lda;
2105
- aoffset2 += lda;
2106
- aoffset3 += lda;
2107
- aoffset4 += lda;
2108
- aoffset5 += lda;
2109
- aoffset6 += lda;
2110
- aoffset7 += lda;
2111
- aoffset8 += lda;
1812
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
1813
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
1814
+ vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
1815
+ vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
1816
+ for (int it = 0; it < 8; it++)
1817
+ aoffsets[it] += lda;
2112
1818
  vecOffset += 256;
2113
1819
  i--;
2114
1820
  } while(i > 0);
@@ -2118,129 +1824,53 @@ class tinyBLAS_Q0_PPC {
2118
1824
  }
2119
1825
 
2120
1826
  if (rows & 4) {
2121
- aoffset1 = aoffset;
2122
- aoffset2 = aoffset1 + lda;
2123
- aoffset3 = aoffset2 + lda;
2124
- aoffset4 = aoffset3 + lda;
2125
- aoffset += 4 * lda;
2126
-
1827
+ aoffsets[0] = aoffset;
1828
+ for (int it = 1; it < 4; it++ )
1829
+ aoffsets[it] = aoffsets[it-1] + lda;
1830
+ aoffset += 4 * lda;
2127
1831
  i = (cols >> 3);
2128
1832
  if (i > 0) {
2129
1833
  do {
2130
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
2131
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
2132
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
2133
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
2134
-
2135
- __builtin_vsx_disassemble_pair(c1, &C1);
2136
- __builtin_vsx_disassemble_pair(c2, &C2);
2137
- __builtin_vsx_disassemble_pair(c3, &C3);
2138
- __builtin_vsx_disassemble_pair(c4, &C4);
2139
-
2140
- t1 = vec_perm(c1[0], c2[0], swiz1);
2141
- t2 = vec_perm(c1[0], c2[0], swiz2);
2142
- t3 = vec_perm(c3[0], c4[0], swiz1);
2143
- t4 = vec_perm(c3[0], c4[0], swiz2);
2144
- t5 = vec_perm(t1, t3, swiz3);
2145
- t6 = vec_perm(t1, t3, swiz4);
2146
- t7 = vec_perm(t2, t4, swiz3);
2147
- t8 = vec_perm(t2, t4, swiz4);
2148
- if (flip == true) {
2149
- t5 = vec_xor(t5, xor_vector);
2150
- t6 = vec_xor(t6, xor_vector);
2151
- t7 = vec_xor(t7, xor_vector);
2152
- t8 = vec_xor(t8, xor_vector);
1834
+ for (int it = 0; it < 4; it++) {
1835
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
1836
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
1837
+ c1[it] = c[it][0];
1838
+ c2[it] = c[it][1];
2153
1839
  }
2154
- vec_xst(t5, 0, vecOffset);
2155
- vec_xst(t6, 0, vecOffset+16);
2156
- vec_xst(t7, 0, vecOffset+32);
2157
- vec_xst(t8, 0, vecOffset+48);
2158
-
2159
- t1 = vec_perm(c1[1], c2[1], swiz1);
2160
- t2 = vec_perm(c1[1], c2[1], swiz2);
2161
- t3 = vec_perm(c3[1], c4[1], swiz1);
2162
- t4 = vec_perm(c3[1], c4[1], swiz2);
2163
- t5 = vec_perm(t1, t3, swiz3);
2164
- t6 = vec_perm(t1, t3, swiz4);
2165
- t7 = vec_perm(t2, t4, swiz3);
2166
- t8 = vec_perm(t2, t4, swiz4);
2167
- if (flip == true) {
2168
- t5 = vec_xor(t5, xor_vector);
2169
- t6 = vec_xor(t6, xor_vector);
2170
- t7 = vec_xor(t7, xor_vector);
2171
- t8 = vec_xor(t8, xor_vector);
1840
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
1841
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
1842
+ for (int it = 0; it < 4; it++) {
1843
+ aoffsets[it] += lda;
2172
1844
  }
2173
- vec_xst(t5, 0, vecOffset+64);
2174
- vec_xst(t6, 0, vecOffset+80);
2175
- vec_xst(t7, 0, vecOffset+96);
2176
- vec_xst(t8, 0, vecOffset+112);
2177
-
2178
- aoffset1 += lda;
2179
- aoffset2 += lda;
2180
- aoffset3 += lda;
2181
- aoffset4 += lda;
2182
1845
  vecOffset += 128;
2183
1846
  i--;
2184
1847
  } while(i > 0);
2185
1848
  }
2186
1849
  }
1850
+
2187
1851
  if (rows & 3) {
2188
- aoffset1 = aoffset;
2189
- aoffset2 = aoffset1 + lda;
2190
- aoffset3 = aoffset2 + lda;
1852
+ aoffsets[0] = aoffset;
1853
+ for (int it = 1; it < 3; it++ )
1854
+ aoffsets[it] = aoffsets[it-1] + lda;
2191
1855
  i = (cols >> 3);
2192
1856
  if (i > 0) {
2193
1857
  do {
2194
1858
  switch(rows) {
2195
- case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
2196
- __builtin_vsx_disassemble_pair(c3, &C3);
2197
- case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
2198
- __builtin_vsx_disassemble_pair(c2, &C2);
2199
- case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
2200
- __builtin_vsx_disassemble_pair(c1, &C1);
1859
+ case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
1860
+ __builtin_vsx_disassemble_pair(c[2], &arr[2]);
1861
+ c1[2] = c[2][0]; c2[2] = c[2][1];
1862
+ case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
1863
+ __builtin_vsx_disassemble_pair(c[1], &arr[1]);
1864
+ c1[1] = c[1][0]; c2[1] = c[1][1];
1865
+ case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
1866
+ __builtin_vsx_disassemble_pair(c[0], &arr[0]);
1867
+ c1[0] = c[0][0]; c2[0] = c[0][1];
2201
1868
  break;
2202
1869
  }
2203
- t1 = vec_perm(c1[0], c2[0], swiz1);
2204
- t2 = vec_perm(c1[0], c2[0], swiz2);
2205
- t3 = vec_perm(c3[0], c4[0], swiz1);
2206
- t4 = vec_perm(c3[0], c4[0], swiz2);
2207
- t5 = vec_perm(t1, t3, swiz3);
2208
- t6 = vec_perm(t1, t3, swiz4);
2209
- t7 = vec_perm(t2, t4, swiz3);
2210
- t8 = vec_perm(t2, t4, swiz4);
2211
- if (flip == true) {
2212
- t5 = vec_xor(t5, xor_vector);
2213
- t6 = vec_xor(t6, xor_vector);
2214
- t7 = vec_xor(t7, xor_vector);
2215
- t8 = vec_xor(t8, xor_vector);
2216
- }
2217
- vec_xst(t5, 0, vecOffset);
2218
- vec_xst(t6, 0, vecOffset+16);
2219
- vec_xst(t7, 0, vecOffset+32);
2220
- vec_xst(t8, 0, vecOffset+48);
2221
-
2222
- t1 = vec_perm(c1[1], c2[1], swiz1);
2223
- t2 = vec_perm(c1[1], c2[1], swiz2);
2224
- t3 = vec_perm(c3[1], c4[1], swiz1);
2225
- t4 = vec_perm(c3[1], c4[1], swiz2);
2226
- t5 = vec_perm(t1, t3, swiz3);
2227
- t6 = vec_perm(t1, t3, swiz4);
2228
- t7 = vec_perm(t2, t4, swiz3);
2229
- t8 = vec_perm(t2, t4, swiz4);
2230
- if (flip == true) {
2231
- t5 = vec_xor(t5, xor_vector);
2232
- t6 = vec_xor(t6, xor_vector);
2233
- t7 = vec_xor(t7, xor_vector);
2234
- t8 = vec_xor(t8, xor_vector);
2235
- }
2236
- vec_xst(t5, 0, vecOffset+64);
2237
- vec_xst(t6, 0, vecOffset+80);
2238
- vec_xst(t7, 0, vecOffset+96);
2239
- vec_xst(t8, 0, vecOffset+112);
2240
-
2241
- aoffset1 += lda;
2242
- aoffset2 += lda;
2243
- aoffset3 += lda;
1870
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
1871
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
1872
+ for (int it = 0; it < 3; it++)
1873
+ aoffsets[it] += lda;
2244
1874
  vecOffset += 128;
2245
1875
  i--;
2246
1876
  } while(i > 0);
@@ -2249,159 +1879,42 @@ class tinyBLAS_Q0_PPC {
2249
1879
  }
2250
1880
 
2251
1881
  void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2252
- int64_t mc, nc, mp, np;
2253
- int m_rem = MIN(m - m0, 8);
2254
- int n_rem = MIN(n - n0, 8);
2255
- // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance
2256
- // issues. After resolving them, below code will be enabled.
2257
- /*if (m_rem >= 16 && n_rem >= 8) {
2258
- mc = 16;
2259
- nc = 8;
2260
- gemm<16,8>(m0, m, n0, n);
2261
- } else if(m_rem >= 8 && n_rem >= 16) {
2262
- mc = 8;
2263
- nc = 16;
2264
- gemm<8,16>(m0, m, n0, n);
2265
- }*/
1882
+ int m_rem = MIN(m - m0, 16);
1883
+ int n_rem = MIN(n - n0, 16);
1884
+
1885
+ int mc = 0, nc = 0;
1886
+
2266
1887
  if (m_rem >= 8 && n_rem >= 8) {
2267
- mc = 8;
2268
- nc = 8;
2269
- gemm<8,8>(m0, m, n0, n);
1888
+ mc = 8;
1889
+ nc = 8;
1890
+ gemm<8, 8>(m0, m, n0, n);
2270
1891
  } else if (m_rem >= 4 && n_rem >= 8) {
2271
1892
  mc = 4;
2272
1893
  nc = 8;
2273
- gemm<4,8>(m0, m, n0, n);
1894
+ gemm<4, 8>(m0, m, n0, n);
2274
1895
  } else if (m_rem >= 8 && n_rem >= 4) {
2275
1896
  mc = 8;
2276
1897
  nc = 4;
2277
- gemm<8,4>(m0, m, n0, n);
1898
+ gemm<8, 4>(m0, m, n0, n);
2278
1899
  } else if (m_rem >= 4 && n_rem >= 4) {
2279
1900
  mc = 4;
2280
1901
  nc = 4;
2281
- gemm_small<4, 4>(m0, m, n0, n);
2282
- } else if ((m_rem < 4) && (n_rem > 4)) {
2283
- nc = 4;
2284
- switch(m_rem) {
2285
- case 1:
2286
- mc = 1;
2287
- gemm_small<1, 4>(m0, m, n0, n);
2288
- break;
2289
- case 2:
2290
- mc = 2;
2291
- gemm_small<2, 4>(m0, m, n0, n);
2292
- break;
2293
- case 3:
2294
- mc = 3;
2295
- gemm_small<3, 4>(m0, m, n0, n);
2296
- break;
2297
- default:
2298
- return;
2299
- }
2300
- } else if ((m_rem > 4) && (n_rem < 4)) {
2301
- mc = 4;
2302
- switch(n_rem) {
2303
- case 1:
2304
- nc = 1;
2305
- gemm_small<4, 1>(m0, m, n0, n);
2306
- break;
2307
- case 2:
2308
- nc = 2;
2309
- gemm_small<4, 2>(m0, m, n0, n);
2310
- break;
2311
- case 3:
2312
- nc = 3;
2313
- gemm_small<4, 3>(m0, m, n0, n);
2314
- break;
2315
- default:
2316
- return;
2317
- }
1902
+ gemm_small(m0, m, n0, n, mc, nc);
2318
1903
  } else {
2319
- switch((m_rem << 4) | n_rem) {
2320
- case 0x43:
2321
- mc = 4;
2322
- nc = 3;
2323
- gemm_small<4, 3>(m0, m, n0, n);
2324
- break;
2325
- case 0x42:
2326
- mc = 4;
2327
- nc = 2;
2328
- gemm_small<4, 2>(m0, m, n0, n);
2329
- break;
2330
- case 0x41:
2331
- mc = 4;
2332
- nc = 1;
2333
- gemm_small<4, 1>(m0, m, n0, n);
2334
- break;
2335
- case 0x34:
2336
- mc = 3;
2337
- nc = 4;
2338
- gemm_small<3, 4>(m0, m, n0, n);
2339
- break;
2340
- case 0x33:
2341
- mc = 3;
2342
- nc = 3;
2343
- gemm_small<3, 3>(m0, m, n0, n);
2344
- break;
2345
- case 0x32:
2346
- mc = 3;
2347
- nc = 2;
2348
- gemm_small<3, 2>(m0, m, n0, n);
2349
- break;
2350
- case 0x31:
2351
- mc = 3;
2352
- nc = 1;
2353
- gemm_small<3, 1>(m0, m, n0, n);
2354
- break;
2355
- case 0x24:
2356
- mc = 2;
2357
- nc = 4;
2358
- gemm_small<2, 4>(m0, m, n0, n);
2359
- break;
2360
- case 0x23:
2361
- mc = 2;
2362
- nc = 3;
2363
- gemm_small<2, 3>(m0, m, n0, n);
2364
- break;
2365
- case 0x22:
2366
- mc = 2;
2367
- nc = 2;
2368
- gemm_small<2, 2>(m0, m, n0, n);
2369
- break;
2370
- case 0x21:
2371
- mc = 2;
2372
- nc = 1;
2373
- gemm_small<2, 1>(m0, m, n0, n);
2374
- break;
2375
- case 0x14:
2376
- mc = 1;
2377
- nc = 4;
2378
- gemm_small<1, 4>(m0, m, n0, n);
2379
- break;
2380
- case 0x13:
2381
- mc = 1;
2382
- nc = 3;
2383
- gemm_small<1, 3>(m0, m, n0, n);
2384
- break;
2385
- case 0x12:
2386
- mc = 1;
2387
- nc = 2;
2388
- gemm_small<1, 2>(m0, m, n0, n);
2389
- break;
2390
- case 0x11:
2391
- mc = 1;
2392
- nc = 1;
2393
- gemm_small<1, 1>(m0, m, n0, n);
2394
- break;
2395
- default:
2396
- return;
2397
- }
1904
+ mc = (m_rem >= 4) ? 4 : m_rem;
1905
+ nc = (n_rem >= 4) ? 4 : n_rem;
1906
+ if (mc == 0 || nc == 0)
1907
+ return;
1908
+ gemm_small(m0, m, n0, n, mc, nc);
2398
1909
  }
2399
- mp = m0 + (m - m0) / mc * mc;
2400
- np = n0 + (n - n0) / nc * nc;
1910
+
1911
+ int64_t mp = m0 + ((m - m0) / mc) * mc;
1912
+ int64_t np = n0 + ((n - n0) / nc) * nc;
2401
1913
  mnpack(mp, m, n0, np);
2402
1914
  mnpack(m0, m, np, n);
2403
1915
  }
2404
1916
 
1917
+
2405
1918
  void KERNEL_4x8(int64_t ii, int64_t jj) {
2406
1919
  vec_t vec_A[8], vec_B[16] = {0};
2407
1920
  acc_t acc_0, acc_1;
@@ -2413,9 +1926,9 @@ class tinyBLAS_Q0_PPC {
2413
1926
  __builtin_mma_xxsetaccz(&acc_0);
2414
1927
  __builtin_mma_xxsetaccz(&acc_1);
2415
1928
  if (std::is_same_v<TA, block_q4_0>) {
2416
- packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
1929
+ packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
2417
1930
  } else {
2418
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
1931
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
2419
1932
  }
2420
1933
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
2421
1934
  for(int x = 0; x < 8; x++) {
@@ -2443,8 +1956,8 @@ class tinyBLAS_Q0_PPC {
2443
1956
  compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
2444
1957
  compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
2445
1958
  }
2446
- save_res<4, 4>(ii, jj, 0, fin_res);
2447
- save_res<4, 4>(ii, jj+4, 4, fin_res);
1959
+ save_res(ii, jj, 0, fin_res);
1960
+ save_res(ii, jj+4, 4, fin_res);
2448
1961
  }
2449
1962
 
2450
1963
  void KERNEL_8x4(int64_t ii, int64_t jj) {
@@ -2458,9 +1971,9 @@ class tinyBLAS_Q0_PPC {
2458
1971
  __builtin_mma_xxsetaccz(&acc_0);
2459
1972
  __builtin_mma_xxsetaccz(&acc_1);
2460
1973
  if (std::is_same_v<TA, block_q4_0>) {
2461
- packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
1974
+ packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2462
1975
  } else {
2463
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
1976
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2464
1977
  }
2465
1978
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
2466
1979
  for(int x = 0; x < 8; x++) {
@@ -2487,8 +2000,8 @@ class tinyBLAS_Q0_PPC {
2487
2000
  compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
2488
2001
  compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
2489
2002
  }
2490
- save_res<4, 4>(ii, jj, 0, fin_res);
2491
- save_res<4, 4>(ii+4, jj, 4, fin_res);
2003
+ save_res(ii, jj, 0, fin_res);
2004
+ save_res(ii+4, jj, 4, fin_res);
2492
2005
  }
2493
2006
 
2494
2007
  void KERNEL_8x8(int64_t ii, int64_t jj) {
@@ -2504,9 +2017,9 @@ class tinyBLAS_Q0_PPC {
2504
2017
  __builtin_mma_xxsetaccz(&acc_2);
2505
2018
  __builtin_mma_xxsetaccz(&acc_3);
2506
2019
  if (std::is_same_v<TA, block_q4_0>) {
2507
- packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2020
+ packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2508
2021
  } else {
2509
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2022
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2510
2023
  }
2511
2024
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
2512
2025
  for(int x = 0; x < 8; x++) {
@@ -2538,14 +2051,13 @@ class tinyBLAS_Q0_PPC {
2538
2051
  compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
2539
2052
  compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
2540
2053
  }
2541
- save_res<4, 4>(ii, jj, 0, fin_res);
2542
- save_res<4, 4>(ii+4, jj, 4, fin_res);
2543
- save_res<4, 4>(ii, jj+4, 8, fin_res);
2544
- save_res<4, 4>(ii+4, jj+4, 12, fin_res);
2054
+ save_res(ii, jj, 0, fin_res);
2055
+ save_res(ii+4, jj, 4, fin_res);
2056
+ save_res(ii, jj+4, 8, fin_res);
2057
+ save_res(ii+4, jj+4, 12, fin_res);
2545
2058
  }
2546
2059
 
2547
- template<int RM, int RN>
2548
- void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2060
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
2549
2061
  int64_t ytiles = (m - m0) / RM;
2550
2062
  int64_t xtiles = (n - n0) / RN;
2551
2063
  int64_t tiles = xtiles * ytiles;
@@ -2574,9 +2086,9 @@ class tinyBLAS_Q0_PPC {
2574
2086
  __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
2575
2087
  __builtin_mma_xxsetaccz(&acc_0);
2576
2088
  if (isAblock_q4) {
2577
- packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
2089
+ packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
2578
2090
  } else {
2579
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
2091
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
2580
2092
  }
2581
2093
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
2582
2094
  for(int x = 0; x < 8; x+=4) {
@@ -2609,7 +2121,7 @@ class tinyBLAS_Q0_PPC {
2609
2121
  fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
2610
2122
  }
2611
2123
  }
2612
- save_res<RM, RN>(ii, jj, 0, fin_res);
2124
+ save_res(ii, jj, 0, fin_res, RM, RN);
2613
2125
  }
2614
2126
  }
2615
2127
 
@@ -2622,7 +2134,7 @@ class tinyBLAS_Q0_PPC {
2622
2134
  } else if constexpr(RM == 8 && RN == 8) {
2623
2135
  KERNEL_8x8(ii,jj);
2624
2136
  } else {
2625
- static_assert(false, "RN/RM values not supported");
2137
+ assert(false && "RN/RM values not supported");
2626
2138
  }
2627
2139
  }
2628
2140
 
@@ -2644,10 +2156,8 @@ class tinyBLAS_Q0_PPC {
2644
2156
  }
2645
2157
 
2646
2158
  const TA *const A;
2647
- const TB *const B;
2648
- TC *C;
2649
- TA *At;
2650
- TB *Bt;
2159
+ const block_q8_0 *const B;
2160
+ float *C;
2651
2161
  const int64_t k;
2652
2162
  const int64_t lda;
2653
2163
  const int64_t ldb;
@@ -2656,266 +2166,183 @@ class tinyBLAS_Q0_PPC {
2656
2166
  const int nth;
2657
2167
  };
2658
2168
 
2659
- template <typename TA, typename TB, typename TC>
2660
2169
  class tinyBLAS_PPC {
2661
2170
  public:
2662
2171
  tinyBLAS_PPC(int64_t k,
2663
- const TA *A, int64_t lda,
2664
- const TB *B, int64_t ldb,
2665
- TC *C, int64_t ldc,
2172
+ const float * A, int64_t lda,
2173
+ const float * B, int64_t ldb,
2174
+ float * C, int64_t ldc,
2666
2175
  int ith, int nth)
2667
2176
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
2668
2177
  }
2669
2178
 
2670
2179
  void matmul(int64_t m, int64_t n) {
2671
- mnpack(0, m, 0, n);
2180
+ int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
2181
+ if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
2182
+ matmul_tiled(m, n, mc, nc, kc);
2183
+ } else {
2184
+ mnpack(0, m, 0, n);
2185
+ }
2672
2186
  }
2673
2187
 
2674
2188
  private:
2675
2189
 
2676
- void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
2190
+ inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
2191
+ vec_t vec_C[4];
2192
+ __builtin_mma_disassemble_acc(vec_C, ACC);
2193
+ for (int I = 0; I < 4; I++) {
2194
+ for (int J = 0; J < 4; J++) {
2195
+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
2196
+ }
2197
+ }
2198
+ }
2199
+
2200
+ inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
2201
+ vec_t vec_C[4];
2202
+ __builtin_mma_disassemble_acc(vec_C, ACC);
2203
+ for (int I = 0; I < 4; I++) {
2204
+ for (int J = 0; J < 4; J++) {
2205
+ float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
2206
+ *c_ptr += *((float *)&vec_C[I]+J);
2207
+ }
2208
+ }
2209
+ }
2210
+
2211
+ inline void vector_permute_store_4(vector float * src, float * vecOffset) {
2212
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
2213
+ t1 = vec_mergeh(src[0], src[1]);
2214
+ t2 = vec_mergeh(src[2], src[3]);
2215
+ t3 = vec_mergel(src[0], src[1]);
2216
+ t4 = vec_mergel(src[2], src[3]);
2217
+
2218
+ t5 = vec_xxpermdi(t1, t2, 0);
2219
+ t6 = vec_xxpermdi(t1, t2, 3);
2220
+ t7 = vec_xxpermdi(t3, t4, 0);
2221
+ t8 = vec_xxpermdi(t3, t4, 3);
2222
+
2223
+ vec_xst(t5, 0, vecOffset);
2224
+ vec_xst(t6, 0, vecOffset + 4);
2225
+ vec_xst(t7, 0, vecOffset + 8);
2226
+ vec_xst(t8, 0, vecOffset + 12);
2227
+ }
2228
+
2229
+ inline void vector_permute_store_8(vector float * src, float * vecOffset) {
2230
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
2231
+ t1 = vec_mergeh(src[0], src[1]);
2232
+ t2 = vec_mergeh(src[2], src[3]);
2233
+ t3 = vec_mergeh(src[4], src[5]);
2234
+ t4 = vec_mergeh(src[6], src[7]);
2677
2235
 
2678
- template<typename VA>
2679
- void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
2236
+ t5 = vec_xxpermdi(t1, t2, 0);
2237
+ t6 = vec_xxpermdi(t3, t4, 0);
2238
+ t7 = vec_xxpermdi(t1, t2, 3);
2239
+ t8 = vec_xxpermdi(t3, t4, 3);
2240
+
2241
+ vec_xst(t5, 0, vecOffset);
2242
+ vec_xst(t6, 0, vecOffset + 4);
2243
+ vec_xst(t7, 0, vecOffset + 8);
2244
+ vec_xst(t8, 0, vecOffset + 12);
2245
+
2246
+ t1 = vec_mergel(src[0], src[1]);
2247
+ t2 = vec_mergel(src[2], src[3]);
2248
+ t3 = vec_mergel(src[4], src[5]);
2249
+ t4 = vec_mergel(src[6], src[7]);
2250
+
2251
+ t5 = vec_xxpermdi(t1, t2, 0);
2252
+ t6 = vec_xxpermdi(t3, t4, 0);
2253
+ t7 = vec_xxpermdi(t1, t2, 3);
2254
+ t8 = vec_xxpermdi(t3, t4, 3);
2255
+
2256
+ vec_xst(t5, 0, vecOffset + 16);
2257
+ vec_xst(t6, 0, vecOffset + 20);
2258
+ vec_xst(t7, 0, vecOffset + 24);
2259
+ vec_xst(t8, 0, vecOffset + 28);
2260
+ }
2261
+
2262
+ void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
2680
2263
  int64_t i, j;
2681
- TA *aoffset = NULL, *boffset = NULL;
2682
- TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
2683
- TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
2684
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
2685
- VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
2686
- VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
2687
- VA t1, t2, t3, t4, t5, t6, t7, t8;
2688
- aoffset = const_cast<TA*>(a);
2264
+ float * aoffsets[8];
2265
+ float * aoffset = NULL, * boffset = NULL;
2266
+ __vector_pair arr[8];
2267
+ vector float c[8][2] = {0};
2268
+ vector float c1[8] = {0};
2269
+ vector float c2[8] = {0};
2270
+ aoffset = const_cast<float *>(a);
2689
2271
  boffset = vec;
2690
2272
  j = (rows >> 3);
2691
2273
  if (j > 0) {
2692
-
2693
2274
  do {
2694
- aoffset1 = aoffset;
2695
- aoffset2 = aoffset1 + lda;
2696
- aoffset3 = aoffset2 + lda;
2697
- aoffset4 = aoffset3 + lda;
2698
- aoffset5 = aoffset4 + lda;
2699
- aoffset6 = aoffset5 + lda;
2700
- aoffset7 = aoffset6 + lda;
2701
- aoffset8 = aoffset7 + lda;
2275
+ aoffsets[0] = aoffset;
2276
+ for (int it = 1; it < 8; it++)
2277
+ aoffsets[it] = aoffsets[it-1] + lda;
2702
2278
  aoffset += 8 * lda;
2703
2279
  i = (cols >> 3);
2704
2280
  if (i > 0) {
2705
2281
  do {
2706
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
2707
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
2708
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
2709
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
2710
- C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
2711
- C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
2712
- C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
2713
- C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
2714
- __builtin_vsx_disassemble_pair(c1, &C1);
2715
- __builtin_vsx_disassemble_pair(c2, &C2);
2716
- __builtin_vsx_disassemble_pair(c3, &C3);
2717
- __builtin_vsx_disassemble_pair(c4, &C4);
2718
- __builtin_vsx_disassemble_pair(c5, &C5);
2719
- __builtin_vsx_disassemble_pair(c6, &C6);
2720
- __builtin_vsx_disassemble_pair(c7, &C7);
2721
- __builtin_vsx_disassemble_pair(c8, &C8);
2722
-
2723
- t1 = vec_mergeh(c1[0], c2[0]);
2724
- t2 = vec_mergeh(c3[0], c4[0]);
2725
- t3 = vec_mergeh(c5[0], c6[0]);
2726
- t4 = vec_mergeh(c7[0], c8[0]);
2727
- t5 = vec_xxpermdi(t1, t2, 0);
2728
- t6 = vec_xxpermdi(t3, t4, 0);
2729
- t7 = vec_xxpermdi(t1, t2, 3);
2730
- t8 = vec_xxpermdi(t3, t4, 3);
2731
- vec_xst(t5, 0, boffset);
2732
- vec_xst(t6, 0, boffset+4);
2733
- vec_xst(t7, 0, boffset+8);
2734
- vec_xst(t8, 0, boffset+12);
2735
-
2736
- t1 = vec_mergel(c1[0], c2[0]);
2737
- t2 = vec_mergel(c3[0], c4[0]);
2738
- t3 = vec_mergel(c5[0], c6[0]);
2739
- t4 = vec_mergel(c7[0], c8[0]);
2740
- t5 = vec_xxpermdi(t1, t2, 0);
2741
- t6 = vec_xxpermdi(t3, t4, 0);
2742
- t7 = vec_xxpermdi(t1, t2, 3);
2743
- t8 = vec_xxpermdi(t3, t4, 3);
2744
- vec_xst(t5, 0, boffset+16);
2745
- vec_xst(t6, 0, boffset+20);
2746
- vec_xst(t7, 0, boffset+24);
2747
- vec_xst(t8, 0, boffset+28);
2748
-
2749
- t1 = vec_mergeh(c1[1], c2[1]);
2750
- t2 = vec_mergeh(c3[1], c4[1]);
2751
- t3 = vec_mergeh(c5[1], c6[1]);
2752
- t4 = vec_mergeh(c7[1], c8[1]);
2753
- t5 = vec_xxpermdi(t1, t2, 0);
2754
- t6 = vec_xxpermdi(t3, t4, 0);
2755
- t7 = vec_xxpermdi(t1, t2, 3);
2756
- t8 = vec_xxpermdi(t3, t4, 3);
2757
- vec_xst(t5, 0, boffset+32);
2758
- vec_xst(t6, 0, boffset+36);
2759
- vec_xst(t7, 0, boffset+40);
2760
- vec_xst(t8, 0, boffset+44);
2761
-
2762
- t1 = vec_mergel(c1[1], c2[1]);
2763
- t2 = vec_mergel(c3[1], c4[1]);
2764
- t3 = vec_mergel(c5[1], c6[1]);
2765
- t4 = vec_mergel(c7[1], c8[1]);
2766
- t5 = vec_xxpermdi(t1, t2, 0);
2767
- t6 = vec_xxpermdi(t3, t4, 0);
2768
- t7 = vec_xxpermdi(t1, t2, 3);
2769
- t8 = vec_xxpermdi(t3, t4, 3);
2770
- vec_xst(t5, 0, boffset+48);
2771
- vec_xst(t6, 0, boffset+52);
2772
- vec_xst(t7, 0, boffset+56);
2773
- vec_xst(t8, 0, boffset+60);
2774
-
2775
- aoffset1 += 8*lda;
2776
- aoffset2 += 8*lda;
2777
- aoffset3 += 8*lda;
2778
- aoffset4 += 8*lda;
2282
+ for (int it = 0; it < 8; it++) {
2283
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
2284
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2285
+ c1[it] = c[it][0];
2286
+ c2[it] = c[it][1];
2287
+ }
2288
+
2289
+ vector_permute_store_8(c1, boffset);
2290
+ vector_permute_store_8(c2, boffset + 32);
2779
2291
  boffset += 64;
2780
2292
  i--;
2293
+ if (i > 0) {
2294
+ for (int it = 0; it < 8; it++) {
2295
+ aoffsets[it] = aoffsets[it] + 8;
2296
+ }
2297
+ }
2781
2298
  } while(i > 0);
2782
2299
  }
2783
2300
  if (cols & 4) {
2784
- c1[0] = vec_xl(0, aoffset1);
2785
- c2[0] = vec_xl(0, aoffset2);
2786
- c3[0] = vec_xl(0, aoffset3);
2787
- c4[0] = vec_xl(0, aoffset4);
2788
- c5[0] = vec_xl(0, aoffset5);
2789
- c6[0] = vec_xl(0, aoffset6);
2790
- c7[0] = vec_xl(0, aoffset7);
2791
- c8[0] = vec_xl(0, aoffset8);
2792
-
2793
- t1 = vec_mergeh(c1[0], c2[0]);
2794
- t2 = vec_mergeh(c3[0], c4[0]);
2795
- t3 = vec_mergeh(c5[0], c6[0]);
2796
- t4 = vec_mergeh(c7[0], c8[0]);
2797
- t5 = vec_xxpermdi(t1, t2, 0);
2798
- t6 = vec_xxpermdi(t3, t4, 0);
2799
- t7 = vec_xxpermdi(t1, t2, 3);
2800
- t8 = vec_xxpermdi(t3, t4, 3);
2801
- vec_xst(t5, 0, boffset);
2802
- vec_xst(t6, 0, boffset+4);
2803
- vec_xst(t7, 0, boffset+8);
2804
- vec_xst(t8, 0, boffset+12);
2805
-
2806
- t1 = vec_mergel(c1[0], c2[0]);
2807
- t2 = vec_mergel(c3[0], c4[0]);
2808
- t3 = vec_mergel(c5[0], c6[0]);
2809
- t4 = vec_mergel(c7[0], c8[0]);
2810
- t5 = vec_xxpermdi(t1, t2, 0);
2811
- t6 = vec_xxpermdi(t3, t4, 0);
2812
- t7 = vec_xxpermdi(t1, t2, 3);
2813
- t8 = vec_xxpermdi(t3, t4, 3);
2814
- vec_xst(t5, 0, boffset+16);
2815
- vec_xst(t6, 0, boffset+20);
2816
- vec_xst(t7, 0, boffset+24);
2817
- vec_xst(t8, 0, boffset+28);
2301
+ for (int it = 0; it < 8 ; it++)
2302
+ c1[it] = vec_xl(0, aoffsets[it]);
2303
+ vector_permute_store_8(c1, boffset);
2818
2304
  }
2819
2305
  j--;
2820
2306
  } while(j > 0);
2821
2307
  }
2822
2308
 
2823
2309
  if (rows & 4) {
2824
- aoffset1 = aoffset;
2825
- aoffset2 = aoffset1 + lda;
2826
- aoffset3 = aoffset2 + lda;
2827
- aoffset4 = aoffset3 + lda;
2310
+ aoffsets[0] = aoffset;
2311
+ for (int it = 1; it < 4; it++)
2312
+ aoffsets[it] = aoffsets[it-1] + lda;
2828
2313
  aoffset += 4 * lda;
2829
2314
  i = (cols >> 3);
2830
2315
  if (i > 0) {
2831
2316
  do {
2832
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
2833
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
2834
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
2835
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
2836
- __builtin_vsx_disassemble_pair(c1, &C1);
2837
- __builtin_vsx_disassemble_pair(c2, &C2);
2838
- __builtin_vsx_disassemble_pair(c3, &C3);
2839
- __builtin_vsx_disassemble_pair(c4, &C4);
2840
-
2841
- t1 = vec_mergeh(c1[0], c2[0]);
2842
- t2 = vec_mergeh(c3[0], c4[0]);
2843
- t3 = vec_mergel(c1[0], c2[0]);
2844
- t4 = vec_mergel(c3[0], c4[0]);
2845
- t5 = vec_xxpermdi(t1, t2, 0);
2846
- t6 = vec_xxpermdi(t1, t2, 3);
2847
- t7 = vec_xxpermdi(t3, t4, 0);
2848
- t8 = vec_xxpermdi(t3, t4, 3);
2849
- vec_xst(t5, 0, boffset);
2850
- vec_xst(t6, 0, boffset+4);
2851
- vec_xst(t7, 0, boffset+8);
2852
- vec_xst(t8, 0, boffset+12);
2853
-
2854
- t1 = vec_mergeh(c1[1], c2[1]);
2855
- t2 = vec_mergeh(c3[1], c4[1]);
2856
- t3 = vec_mergel(c1[1], c2[1]);
2857
- t4 = vec_mergel(c3[1], c4[1]);
2858
- t5 = vec_xxpermdi(t1, t2, 0);
2859
- t6 = vec_xxpermdi(t1, t2, 3);
2860
- t7 = vec_xxpermdi(t3, t4, 0);
2861
- t8 = vec_xxpermdi(t3, t4, 3);
2862
- vec_xst(t5, 0, boffset+16);
2863
- vec_xst(t6, 0, boffset+20);
2864
- vec_xst(t7, 0, boffset+24);
2865
- vec_xst(t8, 0, boffset+28);
2866
-
2867
- aoffset1 += 8*lda;
2868
- aoffset2 += 8*lda;
2869
- aoffset3 += 8*lda;
2870
- aoffset4 += 8*lda;
2317
+ for (int it = 0; it < 4; it++) {
2318
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
2319
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2320
+ c1[it] = c[it][0];
2321
+ c2[it] = c[it][1];
2322
+ }
2323
+ vector_permute_store_4(c1, boffset);
2324
+ vector_permute_store_4(c2, boffset + 16);
2325
+ for (int it = 0; it < 4; it++)
2326
+ aoffsets[it] += 8 * lda;
2871
2327
  boffset += 32;
2872
2328
  i--;
2873
2329
  } while(i > 0);
2874
2330
  }
2875
2331
 
2876
2332
  if (cols & 4) {
2877
- c1[0] = vec_xl(0, aoffset1);
2878
- c2[0] = vec_xl(0, aoffset2);
2879
- c3[0] = vec_xl(0, aoffset3);
2880
- c4[0] = vec_xl(0, aoffset4);
2881
-
2882
- t1 = vec_mergeh(c1[0], c2[0]);
2883
- t2 = vec_mergeh(c3[0], c4[0]);
2884
- t3 = vec_xxpermdi(t1, t2, 0);
2885
- t4 = vec_xxpermdi(t1, t2, 3);
2886
- vec_xst(t3, 0, boffset);
2887
- vec_xst(t4, 0, boffset+4);
2888
-
2889
- t1 = vec_mergel(c1[0], c2[0]);
2890
- t2 = vec_mergel(c3[0], c4[0]);
2891
- t3 = vec_xxpermdi(t1, t2, 0);
2892
- t4 = vec_xxpermdi(t1, t2, 3);
2893
- vec_xst(t3, 0, boffset+8);
2894
- vec_xst(t4, 0, boffset+12);
2333
+ for (int it = 0; it < 4; it++)
2334
+ c1[it] = vec_xl(0, aoffsets[it]);
2335
+ vector_permute_store_4(c1, boffset);
2895
2336
  }
2896
2337
  }
2897
2338
  if (rows & 3) {
2898
- aoffset1 = aoffset;
2899
- aoffset2 = aoffset1 + lda;
2900
- aoffset3 = aoffset2 + lda;
2339
+ aoffsets[0] = aoffset;
2340
+ for (int it = 1; it < 3; it++)
2341
+ aoffsets[it] = aoffsets[it-1] + lda;
2901
2342
  if (cols & 4) {
2902
- c1[0] = vec_xl(0, aoffset1);
2903
- c2[0] = vec_xl(0, aoffset2);
2904
- c3[0] = vec_xl(0, aoffset3);
2905
-
2906
- t1 = vec_mergeh(c1[0], c2[0]);
2907
- t2 = vec_mergeh(c3[0], c4[0]);
2908
- t3 = vec_xxpermdi(t1, t2, 0);
2909
- t4 = vec_xxpermdi(t1, t2, 3);
2910
- vec_xst(t3, 0, boffset);
2911
- vec_xst(t4, 0, boffset+4);
2912
-
2913
- t1 = vec_mergel(c1[0], c2[0]);
2914
- t2 = vec_mergel(c3[0], c4[0]);
2915
- t3 = vec_xxpermdi(t1, t2, 0);
2916
- t4 = vec_xxpermdi(t1, t2, 3);
2917
- vec_xst(t3, 0, boffset+8);
2918
- vec_xst(t4, 0, boffset+12);
2343
+ for (int it = 0; it < 3; it++)
2344
+ c1[it] = vec_xl(0, aoffsets[it]);
2345
+ vector_permute_store_4(c1, boffset);
2919
2346
  }
2920
2347
  }
2921
2348
  }
@@ -2924,15 +2351,15 @@ class tinyBLAS_PPC {
2924
2351
  vec_t vec_A[4], vec_B[4], vec_C[4];
2925
2352
  acc_t acc_0;
2926
2353
  __builtin_mma_xxsetaccz(&acc_0);
2927
- for (int l = 0; l < k; l+=4) {
2928
- packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
2929
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
2354
+ for (int l = 0; l < k; l += 4) {
2355
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
2356
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
2930
2357
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
2931
2358
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
2932
2359
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
2933
2360
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
2934
2361
  }
2935
- SAVE_ACC(&acc_0, ii, jj);
2362
+ save_acc(&acc_0, ii, jj);
2936
2363
  }
2937
2364
 
2938
2365
  void KERNEL_4x8(int64_t ii, int64_t jj) {
@@ -2940,9 +2367,9 @@ class tinyBLAS_PPC {
2940
2367
  acc_t acc_0, acc_1;
2941
2368
  __builtin_mma_xxsetaccz(&acc_0);
2942
2369
  __builtin_mma_xxsetaccz(&acc_1);
2943
- for (int64_t l = 0; l < k; l+=4) {
2944
- packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
2945
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
2370
+ for (int64_t l = 0; l < k; l += 4) {
2371
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
2372
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
2946
2373
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
2947
2374
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
2948
2375
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -2952,8 +2379,8 @@ class tinyBLAS_PPC {
2952
2379
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
2953
2380
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
2954
2381
  }
2955
- SAVE_ACC(&acc_0, ii, jj);
2956
- SAVE_ACC(&acc_1, ii, jj+4);
2382
+ save_acc(&acc_0, ii, jj);
2383
+ save_acc(&acc_1, ii, jj + 4);
2957
2384
  }
2958
2385
 
2959
2386
  void KERNEL_8x4(int64_t ii, int64_t jj) {
@@ -2961,9 +2388,9 @@ class tinyBLAS_PPC {
2961
2388
  acc_t acc_0, acc_1;
2962
2389
  __builtin_mma_xxsetaccz(&acc_0);
2963
2390
  __builtin_mma_xxsetaccz(&acc_1);
2964
- for (int64_t l = 0; l < k; l+=4) {
2965
- packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
2966
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
2391
+ for (int64_t l = 0; l < k; l += 4) {
2392
+ packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
2393
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
2967
2394
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
2968
2395
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
2969
2396
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -2973,8 +2400,8 @@ class tinyBLAS_PPC {
2973
2400
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
2974
2401
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
2975
2402
  }
2976
- SAVE_ACC(&acc_0, ii, jj);
2977
- SAVE_ACC(&acc_1, ii+4, jj);
2403
+ save_acc(&acc_0, ii, jj);
2404
+ save_acc(&acc_1, ii + 4, jj);
2978
2405
  }
2979
2406
 
2980
2407
  void KERNEL_8x8(int64_t ii, int64_t jj) {
@@ -2985,173 +2412,132 @@ class tinyBLAS_PPC {
2985
2412
  __builtin_mma_xxsetaccz(&acc_2);
2986
2413
  __builtin_mma_xxsetaccz(&acc_3);
2987
2414
  for (int l = 0; l < k; l+=8) {
2988
- packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
2989
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
2415
+ packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
2416
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
2990
2417
  for(int x = 0; x < 16; x+=2) {
2991
2418
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
2992
- __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
2993
- __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
2994
- __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
2419
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
2420
+ __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
2421
+ __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
2422
+ }
2423
+ }
2424
+ save_acc(&acc_0, ii, jj);
2425
+ save_acc(&acc_1, ii, jj + 4);
2426
+ save_acc(&acc_2, ii + 4, jj);
2427
+ save_acc(&acc_3, ii + 4, jj + 4);
2428
+ }
2429
+
2430
+ inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
2431
+ for (int x = 0; x < 16; x += 2) {
2432
+ __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
2433
+ __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
2434
+ __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
2435
+ __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
2436
+ __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
2437
+ __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
2438
+ __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
2439
+ __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
2440
+ }
2441
+ }
2442
+
2443
+ void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
2444
+ for (int64_t i = 0; i < mc; i += 16) {
2445
+ int A_base_addr = (mc / 8) * (i / 8) * 16;
2446
+ for (int64_t j = 0; j < nc; j += 8) {
2447
+ int B_base_addr = (nc / 8) * (j / 8) * 16;
2448
+ acc_t acc[8];
2449
+ vec_t A0_block[16]; vec_t A1_block[16];
2450
+ for (int x = 0; x < 8; x++)
2451
+ __builtin_mma_xxsetaccz(&acc[x]);
2452
+ for (int64_t l = 0; l < kc; l += 8) {
2453
+ int A0_block_idx = A_base_addr + (l / 8) * 16;
2454
+ int A1_block_idx = A0_block_idx + (mc / 8) * 16;
2455
+ int B_block_idx = B_base_addr + (l / 8) * 16;
2456
+ vec_t* A0_block = &vec_A[A0_block_idx];
2457
+ vec_t* A1_block = &vec_A[A1_block_idx];
2458
+ vec_t* B_block = &vec_B[B_block_idx];
2459
+ MMA_16x8(A0_block, A1_block, B_block, acc);
2460
+ }
2461
+ if (kk == 0) {
2462
+ save_acc(&acc[0], ii + i, jj + j);
2463
+ save_acc(&acc[1], ii + i, jj + j + 4);
2464
+ save_acc(&acc[2], ii + i + 4, jj + j);
2465
+ save_acc(&acc[3], ii + i + 4, jj + j + 4);
2466
+ save_acc(&acc[4], ii + i + 8, jj + j);
2467
+ save_acc(&acc[5], ii + i + 8, jj + j + 4);
2468
+ save_acc(&acc[6], ii + i + 12, jj + j);
2469
+ save_acc(&acc[7], ii + i + 12, jj + j + 4);
2470
+ } else {
2471
+ add_save_acc(&acc[0], ii + i, jj + j);
2472
+ add_save_acc(&acc[1], ii + i, jj + j + 4);
2473
+ add_save_acc(&acc[2], ii + i + 4, jj + j);
2474
+ add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
2475
+ add_save_acc(&acc[4], ii + i + 8, jj + j);
2476
+ add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
2477
+ add_save_acc(&acc[6], ii + i + 12, jj + j);
2478
+ add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
2479
+ }
2480
+ }
2481
+ }
2482
+ }
2483
+
2484
+ void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) {
2485
+ int64_t ytiles = m / mc;
2486
+ int64_t xtiles = n / nc;
2487
+ int64_t tiles = xtiles * ytiles;
2488
+ int64_t duty = (tiles + nth - 1) / nth;
2489
+ int64_t start = duty * ith;
2490
+ int64_t end = start + duty;
2491
+ if (end > tiles) {
2492
+ end = tiles;
2493
+ }
2494
+ for (int64_t job = start; job < end; ++job) {
2495
+ int64_t ii = (job / xtiles) * mc;
2496
+ int64_t jj = (job % xtiles) * nc;
2497
+ for (int64_t kk = 0; kk < k; kk += kc) {
2498
+ vec_t A_pack[kc * mc / 4];
2499
+ vec_t B_pack[kc * nc / 4];
2500
+ packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
2501
+ packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
2502
+ KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
2995
2503
  }
2996
2504
  }
2997
- SAVE_ACC(&acc_0, ii, jj);
2998
- SAVE_ACC(&acc_1, ii, jj+4);
2999
- SAVE_ACC(&acc_2, ii+4, jj);
3000
- SAVE_ACC(&acc_3, ii+4, jj+4);
3001
2505
  }
3002
2506
 
3003
2507
  void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
3004
- int64_t mc, nc, mp, np;
3005
- int m_rem = MIN(m - m0, 16);
3006
- int n_rem = MIN(n - n0, 16);
3007
- if (m_rem >= 16 && n_rem >= 8) {
3008
- mc = 8;
3009
- nc = 8;
3010
- gemm<8,8>(m0, m, n0, n);
3011
- } else if(m_rem >= 8 && n_rem >= 16) {
3012
- mc = 8;
3013
- nc = 8;
3014
- gemm<8,8>(m0, m, n0, n);
3015
- } else if (m_rem >= 8 && n_rem >= 8) {
2508
+ int m_rem = MIN(m - m0, 8);
2509
+ int n_rem = MIN(n - n0, 8);
2510
+ int mc = 0, nc = 0;
2511
+ if (m_rem >= 8 && n_rem >= 8) {
3016
2512
  mc = 8;
3017
2513
  nc = 8;
3018
- gemm<8,8>(m0, m, n0, n);
2514
+ gemm<8, 8>(m0, m, n0, n);
3019
2515
  } else if (m_rem >= 4 && n_rem >= 8) {
3020
2516
  mc = 4;
3021
2517
  nc = 8;
3022
- gemm<4,8>(m0, m, n0, n);
2518
+ gemm<4, 8>(m0, m, n0, n);
3023
2519
  } else if (m_rem >= 8 && n_rem >= 4) {
3024
2520
  mc = 8;
3025
2521
  nc = 4;
3026
- gemm<8,4>(m0, m, n0, n);
2522
+ gemm<8, 4>(m0, m, n0, n);
3027
2523
  } else if (m_rem >= 4 && n_rem >= 4) {
3028
2524
  mc = 4;
3029
2525
  nc = 4;
3030
- gemm<4,4>(m0, m, n0, n);
3031
- } else if ((m_rem < 4) && (n_rem > 4)) {
3032
- nc = 4;
3033
- switch(m_rem) {
3034
- case 1:
3035
- mc = 1;
3036
- gemm_small(m0, m, n0, n, mc, nc);
3037
- break;
3038
- case 2:
3039
- mc = 2;
3040
- gemm_small(m0, m, n0, n, mc, nc);
3041
- break;
3042
- case 3:
3043
- mc = 3;
3044
- gemm_small(m0, m, n0, n, mc, nc);
3045
- break;
3046
- default:
3047
- return;
3048
- }
3049
- } else if ((m_rem > 4) && (n_rem < 4)) {
3050
- mc = 4;
3051
- switch(n_rem) {
3052
- case 1:
3053
- nc = 1;
3054
- gemm_small(m0, m, n0, n, mc, nc);
3055
- break;
3056
- case 2:
3057
- nc = 2;
3058
- gemm_small(m0, m, n0, n, mc, nc);
3059
- break;
3060
- case 3:
3061
- nc = 3;
3062
- gemm_small(m0, m, n0, n, mc, nc);
3063
- break;
3064
- default:
3065
- return;
3066
- }
2526
+ gemm<4, 4>(m0, m, n0, n);
3067
2527
  } else {
3068
- switch((m_rem << 4) | n_rem) {
3069
- case 0x43:
3070
- mc = 4;
3071
- nc = 3;
3072
- gemm_small(m0, m, n0, n, mc, nc);
3073
- break;
3074
- case 0x42:
3075
- mc = 4;
3076
- nc = 2;
3077
- gemm_small(m0, m, n0, n, mc, nc);
3078
- break;
3079
- case 0x41:
3080
- mc = 4;
3081
- nc = 1;
3082
- gemm_small(m0, m, n0, n, mc, nc);
3083
- break;
3084
- case 0x34:
3085
- mc = 3;
3086
- nc = 4;
3087
- gemm_small(m0, m, n0, n, mc, nc);
3088
- break;
3089
- case 0x33:
3090
- mc = 3;
3091
- nc = 3;
3092
- gemm_small(m0, m, n0, n, mc, nc);
3093
- break;
3094
- case 0x32:
3095
- mc = 3;
3096
- nc = 2;
3097
- gemm_small(m0, m, n0, n, mc, nc);
3098
- break;
3099
- case 0x31:
3100
- mc = 3;
3101
- nc = 1;
3102
- gemm_small(m0, m, n0, n, mc, nc);
3103
- break;
3104
- case 0x24:
3105
- mc = 2;
3106
- nc = 4;
3107
- gemm_small(m0, m, n0, n, mc, nc);
3108
- break;
3109
- case 0x23:
3110
- mc = 2;
3111
- nc = 3;
3112
- gemm_small(m0, m, n0, n, mc, nc);
3113
- break;
3114
- case 0x22:
3115
- mc = 2;
3116
- nc = 2;
3117
- gemm_small(m0, m, n0, n, mc, nc);
3118
- break;
3119
- case 0x21:
3120
- mc = 2;
3121
- nc = 1;
3122
- gemm_small(m0, m, n0, n, mc, nc);
3123
- break;
3124
- case 0x14:
3125
- mc = 1;
3126
- nc = 4;
3127
- gemm_small(m0, m, n0, n, mc, nc);
3128
- break;
3129
- case 0x13:
3130
- mc = 1;
3131
- nc = 3;
3132
- gemm_small(m0, m, n0, n, mc, nc);
3133
- break;
3134
- case 0x12:
3135
- mc = 1;
3136
- nc = 2;
3137
- gemm_small(m0, m, n0, n, mc, nc);
3138
- break;
3139
- case 0x11:
3140
- mc = 1;
3141
- nc = 1;
3142
- gemm_small(m0, m, n0, n, mc, nc);
3143
- break;
3144
- default:
3145
- return;
3146
- }
2528
+ mc = (m_rem >= 4) ? 4 : m_rem;
2529
+ nc = (n_rem >= 4) ? 4 : n_rem;
2530
+ if (mc == 0 || nc == 0)
2531
+ return;
2532
+ gemm_small(m0, m, n0, n, mc, nc);
3147
2533
  }
3148
- mp = m0 + (m - m0) / mc * mc;
3149
- np = n0 + (n - n0) / nc * nc;
2534
+ int64_t mp = m0 + ((m - m0) / mc) * mc;
2535
+ int64_t np = n0 + ((n - n0) / nc) * nc;
3150
2536
  mnpack(mp, m, n0, np);
3151
2537
  mnpack(m0, m, np, n);
3152
2538
  }
3153
2539
 
3154
- void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
2540
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
3155
2541
  int64_t ytiles = (m - m0) / RM;
3156
2542
  int64_t xtiles = (n - n0) / RN;
3157
2543
  int64_t tiles = xtiles * ytiles;
@@ -3166,30 +2552,30 @@ class tinyBLAS_PPC {
3166
2552
  vec_t vec_C[4];
3167
2553
  acc_t acc_0;
3168
2554
  __builtin_mma_xxsetaccz(&acc_0);
3169
- vec_t vec_A[4] {0}, vec_B[4] = {0};
3170
- for (int l=0; l<k; l+=4) {
2555
+ vec_t vec_A[4] = {0}, vec_B[4] = {0};
2556
+ for (int l = 0; l < k; l += 4) {
3171
2557
  /* 'GEMV Forwarding' concept is used in first two conditional loops.
3172
2558
  * when one of the matrix has a single row/column, the elements are
3173
2559
  * broadcasted, instead of using packing routine to prepack the
3174
2560
  * matrix elements.
3175
2561
  */
3176
2562
  if (RM == 1) {
3177
- TA* a = const_cast<TA*>(A+(ii)*lda+l);
3178
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
2563
+ float * a = const_cast<float *>(A + (ii) * lda + l);
2564
+ packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
3179
2565
  vec_A[0] = (vec_t)vec_xl(0,a);
3180
- vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
3181
- vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
3182
- vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
2566
+ vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
2567
+ vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
2568
+ vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
3183
2569
  } else if (RN == 1) {
3184
- packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
3185
- TB* b = const_cast<TB*>(B+(jj)*ldb+l);
2570
+ packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
2571
+ float * b = const_cast<float *>(B + (jj) * ldb + l);
3186
2572
  vec_B[0] = (vec_t)vec_xl(0,b);
3187
- vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
3188
- vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
3189
- vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
2573
+ vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
2574
+ vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
2575
+ vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
3190
2576
  } else {
3191
- packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
3192
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
2577
+ packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
2578
+ packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
3193
2579
  }
3194
2580
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
3195
2581
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -3199,12 +2585,27 @@ class tinyBLAS_PPC {
3199
2585
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
3200
2586
  for (int I = 0; I < RM; I++) {
3201
2587
  for (int J = 0; J < RN; J++) {
3202
- *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
2588
+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
3203
2589
  }
3204
2590
  }
3205
2591
  }
3206
2592
  }
3207
2593
 
2594
+ template<int RM, int RN>
2595
+ inline void kernel(int64_t ii, int64_t jj) {
2596
+ if constexpr(RM == 4 && RN == 4) {
2597
+ KERNEL_4x4(ii, jj);
2598
+ } else if constexpr(RM == 4 && RN == 8) {
2599
+ KERNEL_4x8(ii, jj);
2600
+ } else if constexpr(RM == 8 && RN == 4) {
2601
+ KERNEL_8x4(ii, jj);
2602
+ } else if constexpr(RM == 8 && RN == 8) {
2603
+ KERNEL_8x8(ii, jj);
2604
+ } else {
2605
+ static_assert(false, "RN/RM values not supported");
2606
+ }
2607
+ }
2608
+
3208
2609
  template <int RM, int RN>
3209
2610
  NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
3210
2611
  int64_t ytiles = (m - m0) / RM;
@@ -3213,29 +2614,18 @@ class tinyBLAS_PPC {
3213
2614
  int64_t duty = (tiles + nth - 1) / nth;
3214
2615
  int64_t start = duty * ith;
3215
2616
  int64_t end = start + duty;
3216
- if (RM == 4 && RN == 4) {
3217
- kernel = &tinyBLAS_PPC::KERNEL_4x4;
3218
- } else if (RM == 4 && RN == 8) {
3219
- kernel = &tinyBLAS_PPC::KERNEL_4x8;
3220
- } else if (RM == 8 && RN == 4) {
3221
- kernel = &tinyBLAS_PPC::KERNEL_8x4;
3222
- } else if (RM == 8 && RN == 8) {
3223
- kernel = &tinyBLAS_PPC::KERNEL_8x8;
3224
- }
3225
2617
  if (end > tiles)
3226
2618
  end = tiles;
3227
2619
  for (int64_t job = start; job < end; ++job) {
3228
2620
  int64_t ii = m0 + job / xtiles * RM;
3229
2621
  int64_t jj = n0 + job % xtiles * RN;
3230
- (this->*kernel)(ii, jj);
2622
+ kernel<RM, RN>(ii, jj);
3231
2623
  }
3232
2624
  }
3233
2625
 
3234
- const TA *const A;
3235
- const TB *const B;
3236
- TC *C;
3237
- TA *At;
3238
- TB *Bt;
2626
+ const float * const A;
2627
+ const float * const B;
2628
+ float * C;
3239
2629
  const int64_t k;
3240
2630
  const int64_t lda;
3241
2631
  const int64_t ldb;
@@ -3323,10 +2713,18 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
3323
2713
  (const float *)B, ldb,
3324
2714
  (float *)C, ldc};
3325
2715
  return tb.matmul(m, n);
2716
+ #elif defined(__VXE__) || defined(__VXE2__)
2717
+ if (n < 4)
2718
+ return false;
2719
+ tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
2720
+ k, (const float *)A, lda,
2721
+ (const float *)B, ldb,
2722
+ (float *)C, ldc};
2723
+ return tb.matmul(m, n);
3326
2724
  #elif defined(__MMA__)
3327
2725
  if (k % 8)
3328
2726
  return false;
3329
- tinyBLAS_PPC<float, float, float> tb{
2727
+ tinyBLAS_PPC tb{
3330
2728
  k, (const float *)A, lda,
3331
2729
  (const float *)B, ldb,
3332
2730
  (float *)C, ldc,
@@ -3414,6 +2812,16 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
3414
2812
  (float *)C, ldc};
3415
2813
  return tb.matmul(m, n);
3416
2814
  }
2815
+ #elif defined(__VXE__) || defined(__VXE2__)
2816
+ if (n < 4)
2817
+ return false;
2818
+ if (Btype == GGML_TYPE_F16) {
2819
+ tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
2820
+ k, (const ggml_fp16_t *)A, lda,
2821
+ (const ggml_fp16_t *)B, ldb,
2822
+ (float *)C, ldc};
2823
+ return tb.matmul(m, n);
2824
+ }
3417
2825
  #endif
3418
2826
  return false;
3419
2827
  }
@@ -3443,7 +2851,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
3443
2851
  return false;
3444
2852
  if (m < 8 && m != 4)
3445
2853
  return false;
3446
- tinyBLAS_Q0_PPC<block_q8_0, block_q8_0, float> tb{
2854
+ tinyBLAS_Q0_PPC<block_q8_0> tb{
3447
2855
  k, (const block_q8_0 *)A, lda,
3448
2856
  (const block_q8_0 *)B, ldb,
3449
2857
  (float *)C, ldc,
@@ -3480,7 +2888,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
3480
2888
  return false;
3481
2889
  if (m < 8 && m != 4)
3482
2890
  return false;
3483
- tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float> tb{
2891
+ tinyBLAS_Q0_PPC<block_q4_0> tb{
3484
2892
  k, (const block_q4_0 *)A, lda,
3485
2893
  (const block_q8_0 *)B, ldb,
3486
2894
  (float *)C, ldc,