@fugood/llama.node 0.3.16 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (281)
  1. package/CMakeLists.txt +6 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +44 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +374 -19
  24. package/src/LlamaCompletionWorker.h +31 -10
  25. package/src/LlamaContext.cpp +216 -7
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
  29. package/src/llama.cpp/.github/workflows/build.yml +89 -767
  30. package/src/llama.cpp/.github/workflows/docker.yml +9 -6
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +19 -23
  33. package/src/llama.cpp/CMakeLists.txt +11 -1
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +35 -4
  37. package/src/llama.cpp/common/arg.cpp +844 -121
  38. package/src/llama.cpp/common/arg.h +9 -0
  39. package/src/llama.cpp/common/chat.cpp +129 -107
  40. package/src/llama.cpp/common/chat.h +2 -0
  41. package/src/llama.cpp/common/common.cpp +64 -518
  42. package/src/llama.cpp/common/common.h +35 -45
  43. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  44. package/src/llama.cpp/common/llguidance.cpp +31 -47
  45. package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
  46. package/src/llama.cpp/common/minja/minja.hpp +186 -127
  47. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  48. package/src/llama.cpp/common/regex-partial.h +56 -0
  49. package/src/llama.cpp/common/sampling.cpp +60 -50
  50. package/src/llama.cpp/docs/build.md +122 -7
  51. package/src/llama.cpp/examples/CMakeLists.txt +2 -32
  52. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
  54. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  55. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  56. package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
  57. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  58. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  59. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  60. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  61. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  62. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  64. package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
  65. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  66. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  67. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  68. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  69. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  70. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  71. package/src/llama.cpp/ggml/include/ggml.h +76 -106
  72. package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
  73. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  74. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  75. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  76. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  77. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  78. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  79. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  80. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  81. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  82. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  83. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
  84. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  85. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  86. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  87. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
  89. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  90. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
  91. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
  93. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
  94. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
  95. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
  96. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  101. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  102. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
  103. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  104. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  105. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  106. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  107. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  108. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  109. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
  110. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  111. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
  112. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  113. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
  115. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
  116. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
  117. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  119. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  120. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
  121. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  122. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  123. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  124. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  136. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  137. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  138. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  140. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  141. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
  143. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
  144. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
  145. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
  146. package/src/llama.cpp/ggml/src/ggml.c +170 -265
  147. package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
  148. package/src/llama.cpp/include/llama.h +82 -22
  149. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  150. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  151. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  152. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  153. package/src/llama.cpp/requirements/requirements-all.txt +5 -3
  154. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  155. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  156. package/src/llama.cpp/src/CMakeLists.txt +4 -2
  157. package/src/llama.cpp/src/llama-adapter.cpp +43 -1
  158. package/src/llama.cpp/src/llama-arch.cpp +163 -17
  159. package/src/llama.cpp/src/llama-arch.h +16 -0
  160. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  161. package/src/llama.cpp/src/llama-batch.h +2 -1
  162. package/src/llama.cpp/src/llama-chat.cpp +91 -16
  163. package/src/llama.cpp/src/llama-chat.h +7 -2
  164. package/src/llama.cpp/src/llama-context.cpp +479 -575
  165. package/src/llama.cpp/src/llama-context.h +44 -33
  166. package/src/llama.cpp/src/llama-cparams.h +1 -0
  167. package/src/llama.cpp/src/llama-graph.cpp +209 -157
  168. package/src/llama.cpp/src/llama-graph.h +38 -14
  169. package/src/llama.cpp/src/llama-hparams.h +13 -0
  170. package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
  171. package/src/llama.cpp/src/llama-kv-cache.h +283 -171
  172. package/src/llama.cpp/src/llama-memory.h +12 -2
  173. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  174. package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
  175. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  176. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  177. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  178. package/src/llama.cpp/src/llama-model.cpp +1803 -330
  179. package/src/llama.cpp/src/llama-model.h +21 -2
  180. package/src/llama.cpp/src/llama-quant.cpp +33 -10
  181. package/src/llama.cpp/src/llama-sampling.cpp +25 -7
  182. package/src/llama.cpp/src/llama-vocab.cpp +86 -10
  183. package/src/llama.cpp/src/llama-vocab.h +6 -0
  184. package/src/llama.cpp/src/llama.cpp +15 -1
  185. package/src/llama.cpp/tests/CMakeLists.txt +52 -31
  186. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  187. package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
  188. package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
  189. package/src/llama.cpp/tests/test-chat.cpp +15 -3
  190. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  191. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  192. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  193. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  194. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  195. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  196. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  197. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  198. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  199. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  200. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  201. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  202. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  203. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  204. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
  205. package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
  206. package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
  207. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  208. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
  209. package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
  210. package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
  211. package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
  212. package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
  213. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  214. package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
  215. package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
  216. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  217. package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
  218. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  219. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
  220. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
  221. package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
  222. package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
  223. package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
  224. package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
  225. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  226. package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
  227. package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
  228. package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
  229. package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
  230. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  231. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  232. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  233. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  234. package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
  235. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  236. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  237. package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
  238. package/src/llama.cpp/examples/llava/clip.h +0 -118
  239. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  240. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  241. package/src/llama.cpp/examples/llava/llava.cpp +0 -574
  242. package/src/llama.cpp/examples/llava/llava.h +0 -49
  243. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  244. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
  245. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  246. package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
  247. package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
  248. package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
  249. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  250. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  251. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  252. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  253. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  254. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  255. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  256. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  257. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  258. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  259. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  260. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  261. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  262. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  263. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  264. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  265. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  266. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  267. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  268. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  269. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  270. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  271. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  272. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  273. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  274. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  275. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  276. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  277. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  278. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  279. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  280. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  281. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp

@@ -55,6 +55,7 @@

 #include <atomic>
 #include <array>
+#include <type_traits>

 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
@@ -1053,6 +1054,493 @@ class tinyBLAS_Q0_AVX {
         } \
     } \

+template <typename TA, typename TB, typename TC>
+class tinyBLAS_BF16_PPC {
+  public:
+    tinyBLAS_BF16_PPC(int64_t k,
+                      const TA *A, int64_t lda,
+                      const TB *B, int64_t ldb,
+                      TC *C, int64_t ldc,
+                      int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+    void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {
+        vec_t t[8], s[8];
+        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
+        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
+        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+
+        if (numVec == 2) {
+            t[0] = vec_perm(c[0], c[1], swiz1);
+            t[1] = vec_perm(c[2], c[3], swiz1);
+            s[0] = vec_perm(t[0], t[1], swiz3);
+            s[1] = vec_perm(t[0], t[1], swiz4);
+            vec_xst(s[0], 0, (vec_t*)vecOffset);
+            vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));
+        } else if (numVec == 4) {
+            t[0] = vec_perm(c[0], c[1], swiz1);
+            t[1] = vec_perm(c[0], c[1], swiz2);
+            t[2] = vec_perm(c[2], c[3], swiz1);
+            t[3] = vec_perm(c[2], c[3], swiz2);
+            s[0] = vec_perm(t[0], t[2], swiz3);
+            s[1] = vec_perm(t[0], t[2], swiz4);
+            s[2] = vec_perm(t[1], t[3], swiz3);
+            s[3] = vec_perm(t[1], t[3], swiz4);
+            for (int i = 0; i < 4; ++i)
+                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
+        } else if (numVec == 8) {
+            for (int i = 0; i < 4; i += 2) {
+                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
+                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
+            }
+            for (int i = 4; i < 8; i += 2) {
+                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
+                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
+            }
+            s[0] = vec_perm(t[0], t[2], swiz3);
+            s[1] = vec_perm(t[0], t[2], swiz4);
+            s[2] = vec_perm(t[1], t[3], swiz3);
+            s[3] = vec_perm(t[1], t[3], swiz4);
+            s[4] = vec_perm(t[4], t[6], swiz3);
+            s[5] = vec_perm(t[4], t[6], swiz4);
+            s[6] = vec_perm(t[5], t[7], swiz3);
+            s[7] = vec_perm(t[5], t[7], swiz4);
+            for (int i = 0; i < 8; ++i)
+                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
+        }
+    }
+
+    void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        unsigned char *vecOffset = NULL;
+        TA *aoffsets[8];
+        vector unsigned char c_arr[8];
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                if (cols == 4) {
+                    aoffsets[0] = aoffset;
+                    for (int it = 1; it < 4; ++it)
+                        aoffsets[it] = aoffsets[it-1] + lda;
+                    aoffset += 4 * lda;
+                    for (int i = 0; i < 4; ++i)
+                        c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]);
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int i = 0; i < 4; i++)
+                        aoffsets[i] = aoffsets[i] + lda;
+                    vecOffset += 64;
+                }
+                i = (cols >> 3);
+                if (i > 0) {
+                    aoffsets[0] = aoffset;
+                    for (int it = 1; it < 8; ++it) {
+                        aoffsets[it] = aoffsets[it-1] + lda;
+                    }
+                    aoffset += 8 * lda;
+                    do {
+                        for (int it = 0; it < 8; ++it)
+                            c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                        vector_permute_store(c_arr, 8, vecOffset);
+                        for (int it = 0; it < 8; ++it)
+                            aoffsets[it] = aoffsets[it] + 8 * lda;
+                        vecOffset += 128;
+                        i--;
+                    } while (i > 0);
+                }
+                j--;
+            } while (j > 0);
+        }
+        if (rows & 4) {
+            aoffsets[0] = aoffset;
+            for (int it = 1; it < 4; ++it)
+                aoffsets[it] = aoffsets[it-1] + lda;
+            aoffset += 4 * lda;
+            if (cols == 4) {
+                for (int it = 0; it < 4; ++it)
+                    c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                vector_permute_store(c_arr, 2, vecOffset);
+                for (int it = 0; it < 4; it++)
+                    aoffsets[it] = aoffsets[it] + lda;
+                vecOffset += 32;
+            }
+            i = (cols >> 3);
+            if (i > 0) {
+                do {
+                    for (int it = 0; it < 4; ++it)
+                        c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int it = 0; it < 4; it++)
+                        aoffsets[it] = aoffsets[it] + 8 * lda;
+                    vecOffset += 64;
+                    i--;
+                } while (i > 0);
+            }
+        }
+        if (rows & 3) {
+            aoffsets[0] = aoffset;
+            for (int it = 1; it < 4; ++it)
+                aoffsets[it] = aoffsets[it-1] + lda;
+            if (cols == 4) {
+                switch (rows) {
+                    case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
+                    case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
+                    case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
+                        break;
+                }
+                vector_permute_store(c_arr, 2, vecOffset);
+                for (int it = 0; it < 4; it++)
+                    aoffsets[it] = aoffsets[it] + lda;
+                vecOffset += 32;
+            }
+            i = (cols >> 3);
+            if (i > 0) {
+                do {
+                    switch (rows) {
+                        case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
+                        case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
+                        case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
+                            break;
+                    }
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int it = 0; it < 4; it++)
+                        aoffsets[it] = aoffsets[it] + 8 * lda;
+                    vecOffset += 64;
+                    i--;
+                } while (i > 0);
+            }
+        }
+    }
+
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        int m_rem = MIN(m - m0, 8);
+        int n_rem = MIN(n - n0, 8);
+
+        if (m_rem >= 8 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 8) {
+            mc = 4;
+            nc = 8;
+            gemm<4,8>(m0, m, n0, n);
+        } else if (m_rem >= 8 && n_rem >= 4) {
+            mc = 8;
+            nc = 4;
+            gemm<8,4>(m0, m, n0, n);
+        } else if ((m_rem < 4) && (n_rem >= 8)) {
+            nc = 8;
+            switch (m_rem) {
+                case 1:
+                    mc = 1;
+                    gemm_Mx8<1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    mc = 2;
+                    gemm_Mx8<2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    mc = 3;
+                    gemm_Mx8<3>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else if (m_rem >= 4 && n_rem >= 4) {
+            mc = 4;
+            nc = 4;
+            gemm_small<4, 4>(m0, m, n0, n);
+        } else if ((m_rem > 4) && (n_rem < 4)) {
+            mc = 4;
+            switch (n_rem) {
+                case 1:
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else {
+            switch ((m_rem << 4) | n_rem) {
+                case 0x43:
+                    mc = 4;
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                case 0x42:
+                    mc = 4;
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 0x41:
+                    mc = 4;
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 0x34:
+                    mc = 3;
+                    nc = 4;
+                    gemm_small<3, 4>(m0, m, n0, n);
+                    break;
+                case 0x33:
+                    mc = 3;
+                    nc = 3;
+                    gemm_small<3, 3>(m0, m, n0, n);
+                    break;
+                case 0x32:
+                    mc = 3;
+                    nc = 2;
+                    gemm_small<3, 2>(m0, m, n0, n);
+                    break;
+                case 0x31:
+                    mc = 3;
+                    nc = 1;
+                    gemm_small<3, 1>(m0, m, n0, n);
+                    break;
+                case 0x24:
+                    mc = 2;
+                    nc = 4;
+                    gemm_small<2, 4>(m0, m, n0, n);
+                    break;
+                case 0x23:
+                    mc = 2;
+                    nc = 3;
+                    gemm_small<2, 3>(m0, m, n0, n);
+                    break;
+                case 0x22:
+                    mc = 2;
+                    nc = 2;
+                    gemm_small<2, 2>(m0, m, n0, n);
+                    break;
+                case 0x21:
+                    mc = 2;
+                    nc = 1;
+                    gemm_small<2, 1>(m0, m, n0, n);
+                    break;
+                case 0x14:
+                    mc = 1;
+                    nc = 4;
+                    gemm_small<1, 4>(m0, m, n0, n);
+                    break;
+                case 0x13:
+                    mc = 1;
+                    nc = 3;
+                    gemm_small<1, 3>(m0, m, n0, n);
+                    break;
+                case 0x12:
+                    mc = 1;
+                    nc = 2;
+                    gemm_small<1, 2>(m0, m, n0, n);
+                    break;
+                case 0x11:
+                    mc = 1;
+                    nc = 1;
+                    gemm_small<1, 1>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[4], vec_B[8], vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int l = 0; l < k; l += 8) {
+            packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+    }
+
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[4], vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int l = 0; l < k; l += 8) {
+            packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii+4, jj);
+    }
+
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[8], vec_C[4];
+        acc_t acc_0, acc_1, acc_2, acc_3;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        __builtin_mma_xxsetaccz(&acc_2);
+        __builtin_mma_xxsetaccz(&acc_3);
+        for (int l = 0; l < k; l += 8) {
+            packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
+            packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
+                __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
+            }
+        }
+
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+        SAVE_ACC(&acc_2, ii+4, jj);
+        SAVE_ACC(&acc_3, ii+4, jj+4);
+    }
+
+    template<int RM, int RN>
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0;
+            __builtin_mma_xxsetaccz(&acc_0);
+            vec_t vec_A[2], vec_B[2];
+            for (int l = 0; l < k; l += 4) {
+                packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
+                packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
+                for (int x = 0; x < 2; x++) {
+                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < RN; J++) {
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
+    template<int RM>
+    void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int RN = 8;
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0, acc_1;
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            vec_t vec_A[4], vec_B[8];
+            for (int l = 0; l < k; l += 8) {
+                packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
+                packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
+                for (int x = 0; x < 4; x++) {
+                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                    __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_1);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
+    template<int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr(RM == 4 && RN == 8) {
+            KERNEL_4x8(ii, jj);
+        } else if constexpr(RM == 8 && RN == 8) {
+            KERNEL_8x8(ii, jj);
+        } else if constexpr(RM == 8 && RN == 4) {
+            KERNEL_8x4(ii, jj);
+        } else {
+            static_assert(false, "RN/RM values not supported");
+        }
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            kernel<RM, RN>(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+
 template <typename TA, typename TB, typename TC>
 class tinyBLAS_Q0_PPC {
   public:
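
The new `tinyBLAS_BF16_PPC` class above packs 8×8 tiles of bf16 operands with `vec_perm` shuffles and accumulates them with POWER10 MMA outer products (`__builtin_mma_xvbf16ger2pp`), splitting the tile grid across threads via `ith`/`nth`. As a rough reference for what each kernel computes, here is a minimal scalar sketch in plain C++ (no MMA, no packing; the helper names are illustrative only and are not part of the diff):

#include <cstdint>
#include <cstring>

// bf16 is the upper half of an IEEE-754 float32.
static float bf16_to_f32(uint16_t h) {
    uint32_t u = (uint32_t)h << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

// C (column-major, leading dimension ldc) = A (m x k) * B^T (n x k),
// matching the kernels' C+ii+((jj+J)*ldc) addressing.
void gemm_bf16_ref(int64_t m, int64_t n, int64_t k,
                   const uint16_t *A, int64_t lda,
                   const uint16_t *B, int64_t ldb,
                   float *C, int64_t ldc) {
    for (int64_t j = 0; j < n; ++j) {
        for (int64_t i = 0; i < m; ++i) {
            float acc = 0.0f;
            for (int64_t l = 0; l < k; ++l)
                acc += bf16_to_f32(A[i*lda + l]) * bf16_to_f32(B[j*ldb + l]);
            C[j*ldc + i] = acc;
        }
    }
}
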
@@ -1092,13 +1580,403 @@ class tinyBLAS_Q0_PPC {
        }
    }

-    template<typename VA, typename VB>
-    void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+    template<typename VA, typename VB, int size>
+    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array<int, size>& comparray) {
         int64_t i, j;
         TA *aoffset = NULL;
         VA *vecOffset = NULL;
         TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
         TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        VB t1, t2, t3, t4, t5, t6, t7, t8;
+        const vector signed char lowMask = vec_splats((signed char)0xF);
+        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+        const vector signed char v8 = vec_splats((signed char)0x8);
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+        vector signed int vsum = {0};
+        vector signed int vsum2 = {0};
+
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
+
+                i = (cols >> 2);
+                if (i > 0) {
+                    do {
+                        c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                        c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                        c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                        c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
+                        c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
+                        c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
+                        c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
+                        c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
+
+                        c1[0] = vec_and(c1[1], lowMask);
+                        c1[1] = vec_sr(c1[1], v4);
+                        c1[0] = vec_sub(c1[0], v8);
+                        c1[1] = vec_sub(c1[1], v8);
+                        vsum = vec_sum4s(c1[0], vsum);
+                        vsum2 = vec_sum4s(c1[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c2[0] = vec_and(c2[1], lowMask);
+                        c2[1] = vec_sr(c2[1], v4);
+                        c2[0] = vec_sub(c2[0], v8);
+                        c2[1] = vec_sub(c2[1], v8);
+                        vsum = vec_sum4s(c2[0], vsum);
+                        vsum2 = vec_sum4s(c2[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c3[0] = vec_and(c3[1], lowMask);
+                        c3[1] = vec_sr(c3[1], v4);
+                        c3[0] = vec_sub(c3[0], v8);
+                        c3[1] = vec_sub(c3[1], v8);
+                        vsum = vec_sum4s(c3[0], vsum);
+                        vsum2 = vec_sum4s(c3[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c4[0] = vec_and(c4[1], lowMask);
+                        c4[1] = vec_sr(c4[1], v4);
+                        c4[0] = vec_sub(c4[0], v8);
+                        c4[1] = vec_sub(c4[1], v8);
+                        vsum = vec_sum4s(c4[0], vsum);
+                        vsum2 = vec_sum4s(c4[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c5[0] = vec_and(c5[1], lowMask);
+                        c5[1] = vec_sr(c5[1], v4);
+                        c5[0] = vec_sub(c5[0], v8);
+                        c5[1] = vec_sub(c5[1], v8);
+                        vsum = vec_sum4s(c5[0], vsum);
+                        vsum2 = vec_sum4s(c5[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c6[0] = vec_and(c6[1], lowMask);
+                        c6[1] = vec_sr(c6[1], v4);
+                        c6[0] = vec_sub(c6[0], v8);
+                        c6[1] = vec_sub(c6[1], v8);
+                        vsum = vec_sum4s(c6[0], vsum);
+                        vsum2 = vec_sum4s(c6[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c7[0] = vec_and(c7[1], lowMask);
+                        c7[1] = vec_sr(c7[1], v4);
+                        c7[0] = vec_sub(c7[0], v8);
+                        c7[1] = vec_sub(c7[1], v8);
+                        vsum = vec_sum4s(c7[0], vsum);
+                        vsum2 = vec_sum4s(c7[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c8[0] = vec_and(c8[1], lowMask);
+                        c8[1] = vec_sr(c8[1], v4);
+                        c8[0] = vec_sub(c8[0], v8);
+                        c8[1] = vec_sub(c8[1], v8);
+                        vsum = vec_sum4s(c8[0], vsum);
+                        vsum2 = vec_sum4s(c8[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        t1 = vec_perm(c1[0], c2[0], swiz1);
+                        t2 = vec_perm(c1[0], c2[0], swiz2);
+                        t3 = vec_perm(c3[0], c4[0], swiz1);
+                        t4 = vec_perm(c3[0], c4[0], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset);
+                        vec_xst(t6, 0, vecOffset+16);
+                        vec_xst(t7, 0, vecOffset+32);
+                        vec_xst(t8, 0, vecOffset+48);
+
+                        t1 = vec_perm(c1[1], c2[1], swiz1);
+                        t2 = vec_perm(c1[1], c2[1], swiz2);
+                        t3 = vec_perm(c3[1], c4[1], swiz1);
+                        t4 = vec_perm(c3[1], c4[1], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+64);
+                        vec_xst(t6, 0, vecOffset+80);
+                        vec_xst(t7, 0, vecOffset+96);
+                        vec_xst(t8, 0, vecOffset+112);
+
+                        t1 = vec_perm(c5[0], c6[0], swiz1);
+                        t2 = vec_perm(c5[0], c6[0], swiz2);
+                        t3 = vec_perm(c7[0], c8[0], swiz1);
+                        t4 = vec_perm(c7[0], c8[0], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+128);
+                        vec_xst(t6, 0, vecOffset+144);
+                        vec_xst(t7, 0, vecOffset+160);
+                        vec_xst(t8, 0, vecOffset+176);
+
+                        t1 = vec_perm(c5[1], c6[1], swiz1);
+                        t2 = vec_perm(c5[1], c6[1], swiz2);
+                        t3 = vec_perm(c7[1], c8[1], swiz1);
+                        t4 = vec_perm(c7[1], c8[1], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+192);
+                        vec_xst(t6, 0, vecOffset+208);
+                        vec_xst(t7, 0, vecOffset+224);
+                        vec_xst(t8, 0, vecOffset+240);
+
+                        aoffset1 += lda;
+                        aoffset2 += lda;
+                        aoffset3 += lda;
+                        aoffset4 += lda;
+                        aoffset5 += lda;
+                        aoffset6 += lda;
+                        aoffset7 += lda;
+                        aoffset8 += lda;
+                        vecOffset += 256;
+                        i--;
+                    } while (i > 0);
+                }
+                j--;
+            } while (j > 0);
+        }
+
+        if (rows & 4) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset += 4 * lda;
+
+            i = (cols >> 2);
+            if (i > 0) {
+                do {
+                    c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                    c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                    c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                    c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
+
+                    c1[0] = vec_and(c1[1], lowMask);
+                    c1[1] = vec_sr(c1[1], v4);
+                    c1[0] = vec_sub(c1[0], v8);
+                    c1[1] = vec_sub(c1[1], v8);
+                    vsum = vec_sum4s(c1[0], vsum);
+                    vsum2 = vec_sum4s(c1[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c2[0] = vec_and(c2[1], lowMask);
+                    c2[1] = vec_sr(c2[1], v4);
+                    c2[0] = vec_sub(c2[0], v8);
+                    c2[1] = vec_sub(c2[1], v8);
+                    vsum = vec_sum4s(c2[0], vsum);
+                    vsum2 = vec_sum4s(c2[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c3[0] = vec_and(c3[1], lowMask);
+                    c3[1] = vec_sr(c3[1], v4);
+                    c3[0] = vec_sub(c3[0], v8);
+                    c3[1] = vec_sub(c3[1], v8);
+                    vsum = vec_sum4s(c3[0], vsum);
+                    vsum2 = vec_sum4s(c3[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c4[0] = vec_and(c4[1], lowMask);
+                    c4[1] = vec_sr(c4[1], v4);
+                    c4[0] = vec_sub(c4[0], v8);
+                    c4[1] = vec_sub(c4[1], v8);
+                    vsum = vec_sum4s(c4[0], vsum);
+                    vsum2 = vec_sum4s(c4[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    aoffset4 += lda;
+                    vecOffset += 128;
+                    i--;
+                } while (i > 0);
+            }
+        }
+
+        if (rows & 3) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            i = (cols >> 2);
+            if (i > 0) {
+                do {
+                    switch (rows) {
+                        case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                        case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                        case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                            break;
+                    }
+                    c1[0] = vec_and(c1[1], lowMask);
+                    c1[1] = vec_sr(c1[1], v4);
+                    c1[0] = vec_sub(c1[0], v8);
+                    c1[1] = vec_sub(c1[1], v8);
+                    vsum = vec_sum4s(c1[0], vsum);
+                    vsum2 = vec_sum4s(c1[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c2[0] = vec_and(c2[1], lowMask);
+                    c2[1] = vec_sr(c2[1], v4);
+                    c2[0] = vec_sub(c2[0], v8);
+                    c2[1] = vec_sub(c2[1], v8);
+                    vsum = vec_sum4s(c2[0], vsum);
+                    vsum2 = vec_sum4s(c2[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c3[0] = vec_and(c3[1], lowMask);
+                    c3[1] = vec_sr(c3[1], v4);
+                    c3[0] = vec_sub(c3[0], v8);
+                    c3[1] = vec_sub(c3[1], v8);
+                    vsum = vec_sum4s(c3[0], vsum);
+                    vsum2 = vec_sum4s(c3[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c4[0] = vec_and(c4[1], lowMask);
+                    c4[1] = vec_sr(c4[1], v4);
+                    c4[0] = vec_sub(c4[0], v8);
+                    c4[1] = vec_sub(c4[1], v8);
+                    vsum = vec_sum4s(c4[0], vsum);
+                    vsum2 = vec_sum4s(c4[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    vecOffset += 128;
+                    i--;
+                } while (i > 0);
+            }
+        }
+    }
+
+    template<typename VA, typename VB>
+    void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+        int64_t i, j;
+        TB *aoffset = NULL;
+        VA *vecOffset = NULL;
+        TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
         __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
         VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
         VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
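
`packNormalInt4` above fuses q4_0 dequantization into the packing step: each block's 16-byte `qs` array holds 32 weights as nibble pairs, which are split apart with `vec_and`/`vec_sr`, re-centred by subtracting 8, and summed per row into `comparray` via `vec_sum4s`. A scalar sketch of the same arithmetic (hypothetical helper, not code from the diff):

#include <cstdint>

// One q4_0 block stores 32 weights in 16 bytes: weight i sits in the low
// nibble of qs[i], weight i+16 in the high nibble. Stored values are
// biased by 8, so subtracting 8 recovers the signed range [-8, 7].
void unpack_q4_0_block(const uint8_t qs[16], int8_t out[32], int &row_sum) {
    row_sum = 0;
    for (int i = 0; i < 16; ++i) {
        int8_t lo = (int8_t)(qs[i] & 0x0F) - 8;   // vec_and + vec_sub
        int8_t hi = (int8_t)(qs[i] >> 4) - 8;     // vec_sr  + vec_sub
        out[i]      = lo;
        out[i + 16] = hi;
        row_sum    += lo + hi;                    // vec_sum4s -> comparray
    }
}
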
@@ -1111,24 +1989,24 @@ class tinyBLAS_Q0_PPC {
         vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
         vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};

-        aoffset = const_cast<TA*>(a);
+        aoffset = const_cast<TB*>(a);
         vecOffset = vec;
         j = (rows >> 3);
         if (j > 0) {
             do {
-                aoffset1 = aoffset;
-                aoffset2 = aoffset1 + lda;
-                aoffset3 = aoffset2 + lda;
-                aoffset4 = aoffset3 + lda;
-                aoffset5 = aoffset4 + lda;
-                aoffset6 = aoffset5 + lda;
-                aoffset7 = aoffset6 + lda;
-                aoffset8 = aoffset7 + lda;
-                aoffset += 8 * lda;
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;

-                i = (cols >> 3);
-                if (i > 0) {
-                    do {
+                i = (cols >> 3);
+                if (i > 0) {
+                    do {
                         C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
                         C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
                         C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
@@ -1156,10 +2034,10 @@ class tinyBLAS_Q0_PPC {
                         t7 = vec_perm(t2, t4, swiz3);
                         t8 = vec_perm(t2, t4, swiz4);
                         if (flip == true) {
-                            t5 = vec_xor(t5, xor_vector);
-                            t6 = vec_xor(t6, xor_vector);
-                            t7 = vec_xor(t7, xor_vector);
-                            t8 = vec_xor(t8, xor_vector);
+                            t5 = vec_xor(t5, xor_vector);
+                            t6 = vec_xor(t6, xor_vector);
+                            t7 = vec_xor(t7, xor_vector);
+                            t8 = vec_xor(t8, xor_vector);
                         }
                         vec_xst(t5, 0, vecOffset);
                         vec_xst(t6, 0, vecOffset+16);
@@ -1175,10 +2053,10 @@ class tinyBLAS_Q0_PPC {
                         t7 = vec_perm(t2, t4, swiz3);
                         t8 = vec_perm(t2, t4, swiz4);
                         if (flip == true) {
-                            t5 = vec_xor(t5, xor_vector);
-                            t6 = vec_xor(t6, xor_vector);
-                            t7 = vec_xor(t7, xor_vector);
-                            t8 = vec_xor(t8, xor_vector);
+                            t5 = vec_xor(t5, xor_vector);
+                            t6 = vec_xor(t6, xor_vector);
+                            t7 = vec_xor(t7, xor_vector);
+                            t8 = vec_xor(t8, xor_vector);
                         }
                         vec_xst(t5, 0, vecOffset+64);
                         vec_xst(t6, 0, vecOffset+80);
@@ -1194,10 +2072,10 @@ class tinyBLAS_Q0_PPC {
                         t7 = vec_perm(t2, t4, swiz3);
                         t8 = vec_perm(t2, t4, swiz4);
                         if (flip == true) {
-                            t5 = vec_xor(t5, xor_vector);
-                            t6 = vec_xor(t6, xor_vector);
-                            t7 = vec_xor(t7, xor_vector);
-                            t8 = vec_xor(t8, xor_vector);
+                            t5 = vec_xor(t5, xor_vector);
+                            t6 = vec_xor(t6, xor_vector);
+                            t7 = vec_xor(t7, xor_vector);
+                            t8 = vec_xor(t8, xor_vector);
                         }
                         vec_xst(t5, 0, vecOffset+128);
                         vec_xst(t6, 0, vecOffset+144);
@@ -1213,10 +2091,10 @@ class tinyBLAS_Q0_PPC {
                         t7 = vec_perm(t2, t4, swiz3);
                         t8 = vec_perm(t2, t4, swiz4);
                         if (flip == true) {
-                            t5 = vec_xor(t5, xor_vector);
-                            t6 = vec_xor(t6, xor_vector);
-                            t7 = vec_xor(t7, xor_vector);
-                            t8 = vec_xor(t8, xor_vector);
+                            t5 = vec_xor(t5, xor_vector);
+                            t6 = vec_xor(t6, xor_vector);
+                            t7 = vec_xor(t7, xor_vector);
+                            t8 = vec_xor(t8, xor_vector);
                         }
                         vec_xst(t5, 0, vecOffset+192);
                         vec_xst(t6, 0, vecOffset+208);
@@ -1240,11 +2118,11 @@ class tinyBLAS_Q0_PPC {
         }

         if (rows & 4) {
-            aoffset1 = aoffset;
-            aoffset2 = aoffset1 + lda;
-            aoffset3 = aoffset2 + lda;
-            aoffset4 = aoffset3 + lda;
-            aoffset += 4 * lda;
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset += 4 * lda;

             i = (cols >> 3);
             if (i > 0) {
@@ -1311,7 +2189,7 @@ class tinyBLAS_Q0_PPC {
             aoffset2 = aoffset1 + lda;
             aoffset3 = aoffset2 + lda;
             i = (cols >> 3);
-            if (i > 0) {
+            if (i > 0) {
                 do {
                     switch (rows) {
                         case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
@@ -1527,13 +2405,18 @@ class tinyBLAS_Q0_PPC {
     void KERNEL_4x8(int64_t ii, int64_t jj) {
         vec_t vec_A[8], vec_B[16] = {0};
         acc_t acc_0, acc_1;
-        std::array<int, 4> comparray;
+        std::array<int, 4> comparray {};
         vector float fin_res[8] = {0};
         vector float vs[8] = {0};
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
-            packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+            if (std::is_same_v<TA, block_q4_0>) {
+                packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
+            } else {
+                packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+            }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1545,15 +2428,17 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
                 }
             }
-            auto aoffset = A+(ii*lda)+l;
-            for (int i = 0; i < 4; i++) {
-                comparray[i] = 0;
-                int ca = 0;
-                const int8_t *at = aoffset->qs;
-                for (int j = 0; j < 32; j++)
-                    ca += (int)*at++;
-                comparray[i] = ca;
-                aoffset += lda;
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 4; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
             compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
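
The `isAblock_q4` guard matters because of how these kernels use the signed×unsigned MMA multiply: B is packed with `flip == true`, which XORs its signed q8 bytes with `xor_vector` (flipping the sign bit, i.e. adding 128), and `compute` then subtracts 128 times each row sum of A, the values held in `comparray`, to recover the exact signed dot product. For q4_0 inputs those sums already fall out of `packNormalInt4`, so the scalar summation loop is skipped. A small check of the underlying identity (an assumption-level sketch, not code from the diff):

#include <cassert>

int main() {
    // sum_i a[i]*(b[i] + 128) - 128*sum_i a[i] == sum_i a[i]*b[i]
    int a[4] = {-3, 7, 1, -8};   // signed bytes from a row of A
    int b[4] = {5, -2, 9, 4};    // signed q8 bytes of B before the 0x80 flip
    int biased = 0, row_sum = 0, exact = 0;
    for (int i = 0; i < 4; ++i) {
        biased  += a[i] * (b[i] + 128); // what the unsigned-B path accumulates
        row_sum += a[i];                // what comparray holds for this row
        exact   += a[i] * b[i];         // the true signed dot product
    }
    assert(biased - 128 * row_sum == exact);
    return 0;
}
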
@@ -1565,13 +2450,18 @@ class tinyBLAS_Q0_PPC {
     void KERNEL_8x4(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[8] = {0};
         acc_t acc_0, acc_1;
-        std::array<int, 8> comparray;
+        std::array<int, 8> comparray {};
         vector float fin_res[8] = {0};
         vector float vs[8] = {0};
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
-            packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            if (std::is_same_v<TA, block_q4_0>) {
+                packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+            } else {
+                packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1582,15 +2472,17 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                 }
             }
-            auto aoffset = A+(ii*lda)+l;
-            for (int i = 0; i < 8; i++) {
-                comparray[i] = 0;
-                int ca = 0;
-                const int8_t *at = aoffset->qs;
-                for (int j = 0; j < 32; j++)
-                    ca += (int)*at++;
-                comparray[i] = ca;
-                aoffset += lda;
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 8; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
             compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
@@ -1602,15 +2494,20 @@ class tinyBLAS_Q0_PPC {
1602
2494
  void KERNEL_8x8(int64_t ii, int64_t jj) {
1603
2495
  vec_t vec_A[16], vec_B[16] = {0};
1604
2496
  acc_t acc_0, acc_1, acc_2, acc_3;
1605
- std::array<int, 8> comparray;
2497
+ std::array<int, 8> comparray {};
1606
2498
  vector float fin_res[16] = {0};
1607
2499
  vector float vs[16] = {0};
2500
+ bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
1608
2501
  for (int l = 0; l < k; l++) {
1609
2502
  __builtin_mma_xxsetaccz(&acc_0);
1610
2503
  __builtin_mma_xxsetaccz(&acc_1);
1611
2504
  __builtin_mma_xxsetaccz(&acc_2);
1612
2505
  __builtin_mma_xxsetaccz(&acc_3);
1613
- packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2506
+ if (std::is_same_v<TA, block_q4_0>) {
2507
+ packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2508
+ } else {
2509
+ packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2510
+ }
1614
2511
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
1615
2512
  for(int x = 0; x < 8; x++) {
1616
2513
  __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1624,15 +2521,17 @@ class tinyBLAS_Q0_PPC {
1624
2521
  *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
1625
2522
  }
1626
2523
  }
1627
- auto aoffset = A+(ii*lda)+l;
1628
- for (int i = 0; i < 8; i++) {
1629
- comparray[i] = 0;
1630
- int ca = 0;
1631
- const int8_t *at = aoffset->qs;
1632
- for (int j = 0; j < 32; j++)
1633
- ca += (int)*at++;
1634
- comparray[i] = ca;
1635
- aoffset += lda;
2524
+ if (!isAblock_q4) {
2525
+ auto aoffset = A+(ii*lda)+l;
2526
+ for (int i = 0; i < 8; i++) {
2527
+ comparray[i] = 0;
2528
+ int ca = 0;
2529
+ auto *at = aoffset->qs;
2530
+ for (int j = 0; j < 32; j++)
2531
+ ca += (int)*at++;
2532
+ comparray[i] = ca;
2533
+ aoffset += lda;
2534
+ }
1636
2535
  }
1637
2536
  compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
1638
2537
  compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
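
packNormalInt4's body is not shown in this hunk, but from the call shape it evidently unpacks the q4_0 nibbles into signed int8 lanes and accumulates each row's sum into comparray in the same pass, which is why the kernels skip the separate comparray loop when TA is block_q4_0. A hedged scalar sketch of that one-pass idea (ggml's q4_0 stores 32 weights as 16 bytes of two nibbles with a bias of 8; the real routine additionally interleaves lanes for the MMA tiles):

    // Scalar illustration only -- not the vectorized routine in this diff.
    for (int i = 0; i < rows; ++i) {
        const uint8_t *qs = a[i * lda].qs;
        int sum = 0;
        for (int j = 0; j < 16; ++j) {
            int8_t lo = (int8_t)(qs[j] & 0x0F) - 8;  // element j
            int8_t hi = (int8_t)(qs[j] >> 4)   - 8;  // element j + 16
            out[i][j]      = lo;
            out[i][j + 16] = hi;
            sum += lo + hi;
        }
        comparray[i] = sum;  // consumed later by the -128 correction
    }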
@@ -1653,16 +2552,17 @@ class tinyBLAS_Q0_PPC {
  int64_t duty = (tiles + nth - 1) / nth;
  int64_t start = duty * ith;
  int64_t end = start + duty;
- vec_t vec_A[8], vec_B[8] = {0};
+ vec_t vec_A[8] = {0}, vec_B[8] = {0};
  vector signed int vec_C[4];
  acc_t acc_0;
+ bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;

  if (end > tiles)
  end = tiles;
  for (int64_t job = start; job < end; ++job) {
  int64_t ii = m0 + job / xtiles * RM;
  int64_t jj = n0 + job % xtiles * RN;
- std::array<int, RM> comparray;
+ std::array<int, 4> comparray{};
  vector float res[4] = {0};
  vector float fin_res[4] = {0};
  vector float vs[4] = {0};
@@ -1673,7 +2573,11 @@ class tinyBLAS_Q0_PPC {
  __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
  __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
  __builtin_mma_xxsetaccz(&acc_0);
- packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+ if (isAblock_q4) {
+ packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
+ } else {
+ packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+ }
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
  for(int x = 0; x < 8; x+=4) {
  __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1687,17 +2591,18 @@ class tinyBLAS_Q0_PPC {
  }
  }
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
- auto aoffset = A+(ii*lda)+l;
- for (int i = 0; i < RM; i++) {
- comparray[i] = 0;
- int ca = 0;
- const int8_t *at = aoffset->qs;
- for (int j = 0; j < 32; j++)
- ca += (int)*at++;
- comparray[i] = ca;
- aoffset += lda;
+ if (!isAblock_q4) {
+ auto aoffset = A+(ii*lda)+l;
+ for (int i = 0; i < RM; i++) {
+ comparray[i] = 0;
+ int ca = 0;
+ auto *at = aoffset->qs;
+ for (int j = 0; j < 32; j++)
+ ca += (int)*at++;
+ comparray[i] = ca;
+ aoffset += lda;
+ }
  }
-
  for (int i = 0; i < RM; i++) {
  CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
  res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
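
The fixup above is the standard trick for signed-by-unsigned integer MMA: B is packed as unsigned bytes (the `true` flag on its packNormal call, which by the arithmetic here must apply a +128 bias), so the accumulator holds sum_j a_j * (b_j + 128). Subtracting 128 times the A-row sum cached in comparray recovers the signed dot product, which the d_A * d_B scales in vs then turn into floats. A self-contained scalar check of the identity:

    #include <cstdint>
    #include <cstdio>

    int main() {
        int8_t a[4] = {3, -5, 7, -1}, b[4] = {-2, 4, 6, -8};
        int acc = 0, row_sum = 0;
        for (int j = 0; j < 4; ++j) {
            acc     += a[j] * (int)(uint8_t)(b[j] + 128); // biased product
            row_sum += a[j];                              // comparray analogue
        }
        printf("%d\n", acc - 128 * row_sum);  // prints 24 == 3*-2 + -5*4 + 7*6 + -1*-8
    }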
@@ -1784,6 +2689,7 @@ class tinyBLAS_PPC {
  boffset = vec;
  j = (rows >> 3);
  if (j > 0) {
+
  do {
  aoffset1 = aoffset;
  aoffset2 = aoffset1 + lda;
@@ -2013,6 +2919,7 @@ class tinyBLAS_PPC {
  }
  }
  }
+
  void KERNEL_4x4(int64_t ii, int64_t jj) {
  vec_t vec_A[4], vec_B[4], vec_C[4];
  acc_t acc_0;
@@ -2259,15 +3166,27 @@ class tinyBLAS_PPC {
  vec_t vec_C[4];
  acc_t acc_0;
  __builtin_mma_xxsetaccz(&acc_0);
- vec_t vec_A[4], vec_B[4];
+ vec_t vec_A[4] {0}, vec_B[4] = {0};
  for (int l=0; l<k; l+=4) {
- if (RN >= 4 && RM == 1) {
+ /* The 'GEMV forwarding' concept is used in the first two conditional
+ * branches: when one of the matrices has a single row/column, its
+ * elements are broadcast instead of being prepacked by the packing
+ * routine.
+ */
+ if (RM == 1) {
  TA* a = const_cast<TA*>(A+(ii)*lda+l);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
  vec_A[0] = (vec_t)vec_xl(0,a);
  vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
  vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
  vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
+ } else if (RN == 1) {
+ packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+ TB* b = const_cast<TB*>(B+(jj)*ldb+l);
+ vec_B[0] = (vec_t)vec_xl(0,b);
+ vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
+ vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
+ vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
  } else {
  packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
  packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
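
The comment block describes the intent; conceptually, when RM == 1 there is only one A row, so each of its elements scales an entire strip of B and can simply be splatted across a vector lane, skipping the packTranspose shuffle. A scalar equivalent of one k-step of the RM == 1 case (names hypothetical, float GEMM assumed):

    // c[0..RN-1] accumulates the single output row.
    for (int i = 0; i < 4; ++i) {        // vec_A[i] ~ splat(a[l + i])
        float a_i = a[l + i];
        for (int j = 0; j < RN; ++j)
            c[j] += a_i * b[j * ldb + l + i];
    }

The RN == 1 branch added above is the mirror image: A is packed normally and the lone B column is broadcast instead.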
@@ -2371,8 +3290,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
  assert(params->ith < params->nth);

  // only enable sgemm for prompt processing
+ #if !defined(__MMA__)
  if (n < 2)
  return false;
+ #endif

  if (Ctype != GGML_TYPE_F32)
  return false;
@@ -2442,9 +3363,22 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
  (float *)C, ldc};
  return tb.matmul(m, n);
  }
+ #elif defined(__MMA__)
+ if ((k % 8))
+ return false;
+ if (Btype == GGML_TYPE_BF16) {
+ tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+ (const ggml_bf16_t *)A, lda,
+ (const ggml_bf16_t *)B, ldb,
+ (float *)C, ldc,
+ params->ith, params->nth};
+ tb.matmul(m, n);
+ return true;
+ }
  #endif
  return false;
  }
+
  case GGML_TYPE_F16: {
  #if defined(__AVX512F__)
  if (Btype == GGML_TYPE_F16) {
@@ -2503,8 +3437,8 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
  params->ith, params->nth};
  tb.matmul(m, n);
  return true;
-
  #elif defined(__MMA__)
+ // TO-DO: Remove this condition once gemv forwarding is enabled.
  if (n < 8 && n != 4)
  return false;
  if (m < 8 && m != 4)
@@ -2516,7 +3450,6 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
  params->ith, params->nth};
  tb.matmul(m, n);
  return true;
-
  #else
  return false;
  #endif
@@ -2541,6 +3474,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
  params->ith, params->nth};
  tb.matmul(m, n);
  return true;
+ #elif defined(__MMA__)
+ // TO-DO: Remove this condition once gemv forwarding is enabled.
+ if (n < 8 && n != 4)
+ return false;
+ if (m < 8 && m != 4)
+ return false;
+ tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float> tb{
+ k, (const block_q4_0 *)A, lda,
+ (const block_q8_0 *)B, ldb,
+ (float *)C, ldc,
+ params->ith, params->nth};
+ tb.matmul(m, n);
+ return true;
  #else
  return false;
  #endif
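
Both PPC __MMA__ branches above (the F16 path and the new Q4_0-by-Q8_0 path) carry the same interim shape gate, flagged by the TO-DO comments: m and n must tile into the 4- or 8-wide MMA kernels until GEMV forwarding covers the remaining shapes. The gate expressed as a standalone predicate (hypothetical helper, not in the diff):

    #include <cstdint>

    // Mirrors the `(n < 8 && n != 4) || (m < 8 && m != 4)` rejection above:
    // a dimension is usable only if it is at least 8 or exactly 4.
    static bool mma_shape_supported(int64_t m, int64_t n) {
        auto dim_ok = [](int64_t d) { return d >= 8 || d == 4; };
        return dim_ok(m) && dim_ok(n);
    }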