@fugood/llama.node 0.3.2 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. package/CMakeLists.txt +2 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +1 -1
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +8 -8
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +8 -9
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +4 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +43 -9
  25. package/src/llama.cpp/.github/workflows/docker.yml +3 -0
  26. package/src/llama.cpp/CMakeLists.txt +7 -4
  27. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  28. package/src/llama.cpp/common/CMakeLists.txt +0 -2
  29. package/src/llama.cpp/common/arg.cpp +642 -607
  30. package/src/llama.cpp/common/arg.h +22 -22
  31. package/src/llama.cpp/common/common.cpp +79 -281
  32. package/src/llama.cpp/common/common.h +130 -100
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  34. package/src/llama.cpp/common/log.cpp +50 -50
  35. package/src/llama.cpp/common/log.h +18 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  37. package/src/llama.cpp/common/ngram-cache.h +19 -19
  38. package/src/llama.cpp/common/sampling.cpp +116 -108
  39. package/src/llama.cpp/common/sampling.h +20 -20
  40. package/src/llama.cpp/docs/build.md +37 -17
  41. package/src/llama.cpp/examples/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +14 -14
  43. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  47. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  48. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  49. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  50. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  51. package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
  52. package/src/llama.cpp/examples/infill/infill.cpp +40 -86
  53. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
  54. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  55. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  56. package/src/llama.cpp/examples/llava/clip.cpp +1 -0
  57. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  58. package/src/llama.cpp/examples/llava/llava.cpp +37 -3
  59. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  60. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  61. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  62. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  63. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
  64. package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
  65. package/src/llama.cpp/examples/main/main.cpp +64 -109
  66. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  67. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  68. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  69. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  70. package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
  71. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  72. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
  73. package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
  74. package/src/llama.cpp/examples/server/server.cpp +553 -691
  75. package/src/llama.cpp/examples/server/utils.hpp +312 -25
  76. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  77. package/src/llama.cpp/examples/simple/simple.cpp +128 -96
  78. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  79. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  80. package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
  81. package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
  82. package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
  83. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  84. package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
  85. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  86. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  87. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  88. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  89. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  90. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  91. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  92. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  93. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  94. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  95. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  96. package/src/llama.cpp/ggml/include/ggml.h +53 -393
  97. package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
  98. package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
  99. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  100. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
  101. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  102. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  103. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  104. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  105. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  106. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
  107. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  108. package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
  109. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  110. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
  111. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  112. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
  113. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  114. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  115. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  116. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
  117. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  118. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  120. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  121. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
  122. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  123. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  124. package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
  125. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  126. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
  127. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  128. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  129. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  130. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  131. package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
  132. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  133. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  134. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
  135. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  136. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  137. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
  138. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
  141. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  142. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  143. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
  144. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
  145. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
  146. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  148. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  149. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  150. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  151. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  152. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  153. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  154. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  155. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
  156. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
  157. package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
  158. package/src/llama.cpp/include/llama.h +67 -33
  159. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  160. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  161. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  162. package/src/llama.cpp/src/llama-sampling.cpp +745 -105
  163. package/src/llama.cpp/src/llama-sampling.h +21 -2
  164. package/src/llama.cpp/src/llama-vocab.cpp +49 -9
  165. package/src/llama.cpp/src/llama-vocab.h +35 -11
  166. package/src/llama.cpp/src/llama.cpp +2636 -2406
  167. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  168. package/src/llama.cpp/tests/CMakeLists.txt +1 -2
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
  171. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  172. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  173. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  174. package/src/llama.cpp/tests/test-log.cpp +2 -2
  175. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  176. package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
  177. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  178. package/src/llama.cpp/tests/test-rope.cpp +1 -0
  179. package/src/llama.cpp/tests/test-sampling.cpp +162 -137
  180. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  181. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  182. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  183. package/src/llama.cpp/common/train.cpp +0 -1515
  184. package/src/llama.cpp/common/train.h +0 -233
  185. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  186. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  187. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  188. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  189. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
  190. /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
@@ -50,7 +50,8 @@
50
50
 
51
51
  #include "sgemm.h"
52
52
  #include "ggml-impl.h"
53
- #include "ggml-cpu-impl.h"
53
+ // hack until moved into the CPU backend
54
+ #include "../ggml-cpu-impl.h"
54
55
  #include "ggml-quants.h"
55
56
 
56
57
  #ifdef _MSC_VER
@@ -106,6 +107,10 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
106
107
  inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
107
108
  #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
108
109
 
110
#if defined(__MMA__)
// Operand and accumulator types passed to the __builtin_mma_* intrinsics:
// vec_t is a 16-byte VSX register viewed as bytes, acc_t is a 4x4 MMA
// accumulator (__vector_quad).
using vec_t = __vector unsigned char;
using acc_t = __vector_quad;
#endif
109
114
  ////////////////////////////////////////////////////////////////////////////////////////////////////
110
115
  // VECTORIZED FUSED MULTIPLY ADD
111
116
 
@@ -942,6 +947,36 @@ class tinyBLAS_Q0_AVX {
942
947
  return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
943
948
  }
944
949
 
950
+ inline __m256i load(const block_q5_0 *b) {
951
+ return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
952
+ }
953
+
954
+ inline __m128i load0(const block_q5_0* b) {
955
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
956
+ uint32_t x32;
957
+ memcpy(&x32, b->qh, sizeof(uint32_t));
958
+ __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
959
+ __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
960
+ _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
961
+ _mm_shuffle_epi8(_mm_set1_epi32(x32),
962
+ _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
963
+ bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
964
+ return _mm_or_si128(qxl, bytesl);
965
+ }
966
+
967
+ inline __m128i load1(const block_q5_0* b) {
968
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
969
+ uint32_t x32;
970
+ memcpy(&x32, b->qh, sizeof(uint32_t));
971
+ __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
972
+ __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
973
+ _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
974
+ _mm_shuffle_epi8(_mm_set1_epi32(x32),
975
+ _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
976
+ bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
977
+ return _mm_or_si128(qxh, bytesh);
978
+ }
979
+
945
980
  inline __m256i load(const block_iq4_nl *b) {
946
981
  return MM256_SET_M128I(load1(b), load0(b));
947
982
  }
@@ -973,6 +1008,17 @@ class tinyBLAS_Q0_AVX {
973
1008
  _mm_srli_epi16(x, 4), 1));
974
1009
  }
975
1010
 
1011
+ static inline __m256i bittobyte(const uint8_t *p) {
1012
+ uint32_t x32;
1013
+ memcpy(&x32, p, sizeof(uint32_t));
1014
+ __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
1015
+ _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
1016
+ _mm256_shuffle_epi8(_mm256_set1_epi32(x32),
1017
+ _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
1018
+ 0x0101010101010101, 0x0000000000000000))));
1019
+ return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
1020
+ }
1021
+
976
1022
  const TA *const A;
977
1023
  const TB *const B;
978
1024
  TC *const C;
@@ -985,6 +1031,600 @@ class tinyBLAS_Q0_AVX {
985
1031
  };
986
1032
  #endif // __AVX__
987
1033
 
1034
+ //PPC Implementation
1035
+ #if defined(__MMA__)
1036
+
1037
+ #define SAVE_ACC(ACC, ii, jj) \
1038
+ __builtin_mma_disassemble_acc(vec_C, ACC); \
1039
+ for (int I = 0; I < 4; I++) { \
1040
+ for (int J = 0; J < 4; J++) { \
1041
+ *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
1042
+ } \
1043
+ } \
1044
+
1045
+ template <typename TA, typename TB, typename TC>
1046
+ class tinyBLAS_PPC {
1047
+ public:
1048
+ tinyBLAS_PPC(int64_t k,
1049
+ const TA *A, int64_t lda,
1050
+ const TB *B, int64_t ldb,
1051
+ TC *C, int64_t ldc,
1052
+ int ith, int nth)
1053
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1054
+ }
1055
+
1056
+ void matmul(int64_t m, int64_t n) {
1057
+ mnpack(0, m, 0, n);
1058
+ }
1059
+
1060
+ private:
1061
+
1062
+ void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
1063
+
1064
+ void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
1065
+ int64_t i, j;
1066
+ float *aoffset = NULL, *boffset = NULL;
1067
+ float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1068
+ float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1069
+
1070
+ aoffset = const_cast<float*>(a);
1071
+ boffset = vec;
1072
+ j = (rows >> 3);
1073
+ if (j > 0) {
1074
+ do {
1075
+ aoffset1 = aoffset;
1076
+ aoffset2 = aoffset1 + lda;
1077
+ aoffset3 = aoffset2 + lda;
1078
+ aoffset4 = aoffset3 + lda;
1079
+ aoffset5 = aoffset4 + lda;
1080
+ aoffset6 = aoffset5 + lda;
1081
+ aoffset7 = aoffset6 + lda;
1082
+ aoffset8 = aoffset7 + lda;
1083
+ aoffset += 8 * lda;
1084
+ i = (cols >> 3);
1085
+ if (i > 0) {
1086
+ __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
1087
+ vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
1088
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
1089
+ do {
1090
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
1091
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
1092
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
1093
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
1094
+ C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
1095
+ C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
1096
+ C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
1097
+ C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
1098
+ __builtin_vsx_disassemble_pair(c1, &C1);
1099
+ __builtin_vsx_disassemble_pair(c2, &C2);
1100
+ __builtin_vsx_disassemble_pair(c3, &C3);
1101
+ __builtin_vsx_disassemble_pair(c4, &C4);
1102
+ __builtin_vsx_disassemble_pair(c5, &C5);
1103
+ __builtin_vsx_disassemble_pair(c6, &C6);
1104
+ __builtin_vsx_disassemble_pair(c7, &C7);
1105
+ __builtin_vsx_disassemble_pair(c8, &C8);
1106
+
1107
+ t1 = vec_mergeh(c1[0], c2[0]);
1108
+ t2 = vec_mergeh(c3[0], c4[0]);
1109
+ t3 = vec_mergeh(c5[0], c6[0]);
1110
+ t4 = vec_mergeh(c7[0], c8[0]);
1111
+ t5 = vec_xxpermdi(t1, t2, 0);
1112
+ t6 = vec_xxpermdi(t3, t4, 0);
1113
+ t7 = vec_xxpermdi(t1, t2, 3);
1114
+ t8 = vec_xxpermdi(t3, t4, 3);
1115
+ vec_xst(t5, 0, boffset);
1116
+ vec_xst(t6, 0, boffset+4);
1117
+ vec_xst(t7, 0, boffset+8);
1118
+ vec_xst(t8, 0, boffset+12);
1119
+
1120
+ t1 = vec_mergel(c1[0], c2[0]);
1121
+ t2 = vec_mergel(c3[0], c4[0]);
1122
+ t3 = vec_mergel(c5[0], c6[0]);
1123
+ t4 = vec_mergel(c7[0], c8[0]);
1124
+ t5 = vec_xxpermdi(t1, t2, 0);
1125
+ t6 = vec_xxpermdi(t3, t4, 0);
1126
+ t7 = vec_xxpermdi(t1, t2, 3);
1127
+ t8 = vec_xxpermdi(t3, t4, 3);
1128
+ vec_xst(t5, 0, boffset+16);
1129
+ vec_xst(t6, 0, boffset+20);
1130
+ vec_xst(t7, 0, boffset+24);
1131
+ vec_xst(t8, 0, boffset+28);
1132
+
1133
+ t1 = vec_mergeh(c1[1], c2[1]);
1134
+ t2 = vec_mergeh(c3[1], c4[1]);
1135
+ t3 = vec_mergeh(c5[1], c6[1]);
1136
+ t4 = vec_mergeh(c7[1], c8[1]);
1137
+ t5 = vec_xxpermdi(t1, t2, 0);
1138
+ t6 = vec_xxpermdi(t3, t4, 0);
1139
+ t7 = vec_xxpermdi(t1, t2, 3);
1140
+ t8 = vec_xxpermdi(t3, t4, 3);
1141
+ vec_xst(t5, 0, boffset+32);
1142
+ vec_xst(t6, 0, boffset+36);
1143
+ vec_xst(t7, 0, boffset+40);
1144
+ vec_xst(t8, 0, boffset+44);
1145
+
1146
+ t1 = vec_mergel(c1[1], c2[1]);
1147
+ t2 = vec_mergel(c3[1], c4[1]);
1148
+ t3 = vec_mergel(c5[1], c6[1]);
1149
+ t4 = vec_mergel(c7[1], c8[1]);
1150
+ t5 = vec_xxpermdi(t1, t2, 0);
1151
+ t6 = vec_xxpermdi(t3, t4, 0);
1152
+ t7 = vec_xxpermdi(t1, t2, 3);
1153
+ t8 = vec_xxpermdi(t3, t4, 3);
1154
+ vec_xst(t5, 0, boffset+48);
1155
+ vec_xst(t6, 0, boffset+52);
1156
+ vec_xst(t7, 0, boffset+56);
1157
+ vec_xst(t8, 0, boffset+60);
1158
+
1159
+ aoffset1 += 8*lda;
1160
+ aoffset2 += 8*lda;
1161
+ aoffset3 += 8*lda;
1162
+ aoffset4 += 8*lda;
1163
+ boffset += 64;
1164
+ i--;
1165
+ } while(i > 0);
1166
+ }
1167
+ if (cols & 4) {
1168
+ vector float c1, c2, c3, c4, c5, c6, c7, c8;
1169
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
1170
+ c1 = vec_xl(0, aoffset1);
1171
+ c2 = vec_xl(0, aoffset2);
1172
+ c3 = vec_xl(0, aoffset3);
1173
+ c4 = vec_xl(0, aoffset4);
1174
+ c5 = vec_xl(0, aoffset5);
1175
+ c6 = vec_xl(0, aoffset6);
1176
+ c7 = vec_xl(0, aoffset7);
1177
+ c8 = vec_xl(0, aoffset8);
1178
+
1179
+ t1 = vec_mergeh(c1, c2);
1180
+ t2 = vec_mergeh(c3, c4);
1181
+ t3 = vec_mergeh(c5, c6);
1182
+ t4 = vec_mergeh(c7, c8);
1183
+ t5 = vec_xxpermdi(t1, t2, 0);
1184
+ t6 = vec_xxpermdi(t3, t4, 0);
1185
+ t7 = vec_xxpermdi(t1, t2, 3);
1186
+ t8 = vec_xxpermdi(t3, t4, 3);
1187
+ vec_xst(t5, 0, boffset);
1188
+ vec_xst(t6, 0, boffset+4);
1189
+ vec_xst(t7, 0, boffset+8);
1190
+ vec_xst(t8, 0, boffset+12);
1191
+
1192
+ t1 = vec_mergel(c1, c2);
1193
+ t2 = vec_mergel(c3, c4);
1194
+ t3 = vec_mergel(c5, c6);
1195
+ t4 = vec_mergel(c7, c8);
1196
+ t5 = vec_xxpermdi(t1, t2, 0);
1197
+ t6 = vec_xxpermdi(t3, t4, 0);
1198
+ t7 = vec_xxpermdi(t1, t2, 3);
1199
+ t8 = vec_xxpermdi(t3, t4, 3);
1200
+ vec_xst(t5, 0, boffset+16);
1201
+ vec_xst(t6, 0, boffset+20);
1202
+ vec_xst(t7, 0, boffset+24);
1203
+ vec_xst(t8, 0, boffset+28);
1204
+ }
1205
+ j--;
1206
+ } while(j > 0);
1207
+ }
1208
+
1209
+ if (rows & 4) {
1210
+ aoffset1 = aoffset;
1211
+ aoffset2 = aoffset1 + lda;
1212
+ aoffset3 = aoffset2 + lda;
1213
+ aoffset4 = aoffset3 + lda;
1214
+ aoffset += 4 * lda;
1215
+ i = (cols >> 3);
1216
+ if (i > 0) {
1217
+ __vector_pair C1, C2, C3, C4;
1218
+ vector float c1[2], c2[2], c3[2], c4[2];
1219
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
1220
+ do {
1221
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
1222
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
1223
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
1224
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
1225
+ __builtin_vsx_disassemble_pair(c1, &C1);
1226
+ __builtin_vsx_disassemble_pair(c2, &C2);
1227
+ __builtin_vsx_disassemble_pair(c3, &C3);
1228
+ __builtin_vsx_disassemble_pair(c4, &C4);
1229
+
1230
+ t1 = vec_mergeh(c1[0], c2[0]);
1231
+ t2 = vec_mergeh(c3[0], c4[0]);
1232
+ t3 = vec_mergel(c1[0], c2[0]);
1233
+ t4 = vec_mergel(c3[0], c4[0]);
1234
+ t5 = vec_xxpermdi(t1, t2, 0);
1235
+ t6 = vec_xxpermdi(t1, t2, 3);
1236
+ t7 = vec_xxpermdi(t3, t4, 0);
1237
+ t8 = vec_xxpermdi(t3, t4, 3);
1238
+ vec_xst(t5, 0, boffset);
1239
+ vec_xst(t6, 0, boffset+4);
1240
+ vec_xst(t7, 0, boffset+8);
1241
+ vec_xst(t8, 0, boffset+12);
1242
+
1243
+ t1 = vec_mergeh(c1[1], c2[1]);
1244
+ t2 = vec_mergeh(c3[1], c4[1]);
1245
+ t3 = vec_mergel(c1[1], c2[1]);
1246
+ t4 = vec_mergel(c3[1], c4[1]);
1247
+ t5 = vec_xxpermdi(t1, t2, 0);
1248
+ t6 = vec_xxpermdi(t1, t2, 3);
1249
+ t7 = vec_xxpermdi(t3, t4, 0);
1250
+ t8 = vec_xxpermdi(t3, t4, 3);
1251
+ vec_xst(t5, 0, boffset+16);
1252
+ vec_xst(t6, 0, boffset+20);
1253
+ vec_xst(t7, 0, boffset+24);
1254
+ vec_xst(t8, 0, boffset+28);
1255
+
1256
+ aoffset1 += 8*lda;
1257
+ aoffset2 += 8*lda;
1258
+ aoffset3 += 8*lda;
1259
+ aoffset4 += 8*lda;
1260
+ boffset += 32;
1261
+ i--;
1262
+ } while(i > 0);
1263
+ }
1264
+
1265
+ if (cols & 4) {
1266
+ vector float c1, c2, c3, c4;
1267
+ vector float t1, t2, t3, t4;
1268
+ c1 = vec_xl(0, aoffset1);
1269
+ c2 = vec_xl(0, aoffset2);
1270
+ c3 = vec_xl(0, aoffset3);
1271
+ c4 = vec_xl(0, aoffset4);
1272
+
1273
+ t1 = vec_mergeh(c1, c2);
1274
+ t2 = vec_mergeh(c3, c4);
1275
+ t3 = vec_xxpermdi(t1, t2, 0);
1276
+ t4 = vec_xxpermdi(t1, t2, 3);
1277
+ vec_xst(t3, 0, boffset);
1278
+ vec_xst(t4, 0, boffset+4);
1279
+
1280
+ t1 = vec_mergel(c1, c2);
1281
+ t2 = vec_mergel(c3, c4);
1282
+ t3 = vec_xxpermdi(t1, t2, 0);
1283
+ t4 = vec_xxpermdi(t1, t2, 3);
1284
+ vec_xst(t3, 0, boffset+8);
1285
+ vec_xst(t4, 0, boffset+12);
1286
+ }
1287
+ }
1288
+ if (rows & 3) {
1289
+ aoffset1 = aoffset;
1290
+ aoffset2 = aoffset1 + lda;
1291
+ aoffset3 = aoffset2 + lda;
1292
+ if (cols & 4) {
1293
+ vector float c1, c2, c3, c4 = {0};
1294
+ vector float t1, t2, t3, t4;
1295
+ c1 = vec_xl(0, aoffset1);
1296
+ c2 = vec_xl(0, aoffset2);
1297
+ c3 = vec_xl(0, aoffset3);
1298
+
1299
+ t1 = vec_mergeh(c1, c2);
1300
+ t2 = vec_mergeh(c3, c4);
1301
+ t3 = vec_xxpermdi(t1, t2, 0);
1302
+ t4 = vec_xxpermdi(t1, t2, 3);
1303
+ vec_xst(t3, 0, boffset);
1304
+ vec_xst(t4, 0, boffset+4);
1305
+
1306
+ t1 = vec_mergel(c1, c2);
1307
+ t2 = vec_mergel(c3, c4);
1308
+ t3 = vec_xxpermdi(t1, t2, 0);
1309
+ t4 = vec_xxpermdi(t1, t2, 3);
1310
+ vec_xst(t3, 0, boffset+8);
1311
+ vec_xst(t4, 0, boffset+12);
1312
+ }
1313
+ }
1314
+ }
1315
+
1316
+ void KERNEL_4x4(int64_t ii, int64_t jj) {
1317
+ vec_t vec_A[4], vec_B[4], vec_C[4];
1318
+ acc_t acc_0;
1319
+ __builtin_mma_xxsetaccz(&acc_0);
1320
+ for (int l = 0; l < k; l+=4) {
1321
+ READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
1322
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
1323
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
1324
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
1325
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
1326
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
1327
+ }
1328
+ SAVE_ACC(&acc_0, ii, jj);
1329
+ }
1330
+
1331
+ void KERNEL_4x8(int64_t ii, int64_t jj) {
1332
+ vec_t vec_A[4], vec_B[8], vec_C[4];
1333
+ acc_t acc_0, acc_1;
1334
+ __builtin_mma_xxsetaccz(&acc_0);
1335
+ __builtin_mma_xxsetaccz(&acc_1);
1336
+ for (int64_t l = 0; l < k; l+=4) {
1337
+ READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
1338
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
1339
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
1340
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
1341
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
1342
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
1343
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
1344
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
1345
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
1346
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
1347
+ }
1348
+ SAVE_ACC(&acc_0, ii, jj);
1349
+ SAVE_ACC(&acc_1, ii, jj+4);
1350
+ }
1351
+
1352
+ void KERNEL_8x4(int64_t ii, int64_t jj) {
1353
+ vec_t vec_A[8], vec_B[4], vec_C[4];
1354
+ acc_t acc_0, acc_1;
1355
+ __builtin_mma_xxsetaccz(&acc_0);
1356
+ __builtin_mma_xxsetaccz(&acc_1);
1357
+ for (int64_t l = 0; l < k; l+=4) {
1358
+ READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
1359
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
1360
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
1361
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
1362
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
1363
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
1364
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
1365
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
1366
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
1367
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
1368
+ }
1369
+ SAVE_ACC(&acc_0, ii, jj);
1370
+ SAVE_ACC(&acc_1, ii+4, jj);
1371
+ }
1372
+
1373
+ void KERNEL_8x8(int64_t ii, int64_t jj) {
1374
+ vec_t vec_A[16], vec_B[16], vec_C[4];
1375
+ acc_t acc_0, acc_1, acc_2, acc_3;
1376
+ __builtin_mma_xxsetaccz(&acc_0);
1377
+ __builtin_mma_xxsetaccz(&acc_1);
1378
+ __builtin_mma_xxsetaccz(&acc_2);
1379
+ __builtin_mma_xxsetaccz(&acc_3);
1380
+ for (int l = 0; l < k; l+=8) {
1381
+ READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
1382
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
1383
+ for(int x = 0; x < 16; x+=2) {
1384
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
1385
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
1386
+ __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
1387
+ __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
1388
+ }
1389
+ }
1390
+ SAVE_ACC(&acc_0, ii, jj);
1391
+ SAVE_ACC(&acc_1, ii, jj+4);
1392
+ SAVE_ACC(&acc_2, ii+4, jj);
1393
+ SAVE_ACC(&acc_3, ii+4, jj+4);
1394
+ }
1395
+
1396
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1397
+ int64_t mc, nc, mp, np;
1398
+ int m_rem = MIN(m - m0, 16);
1399
+ int n_rem = MIN(n - n0, 16);
1400
+ if (m_rem >= 16 && n_rem >= 8) {
1401
+ mc = 8;
1402
+ nc = 8;
1403
+ gemm<8,8>(m0, m, n0, n);
1404
+ } else if(m_rem >= 8 && n_rem >= 16) {
1405
+ mc = 8;
1406
+ nc = 8;
1407
+ gemm<8,8>(m0, m, n0, n);
1408
+ } else if (m_rem >= 8 && n_rem >= 8) {
1409
+ mc = 8;
1410
+ nc = 8;
1411
+ gemm<8,8>(m0, m, n0, n);
1412
+ } else if (m_rem >= 4 && n_rem >= 8) {
1413
+ mc = 4;
1414
+ nc = 8;
1415
+ gemm<4,8>(m0, m, n0, n);
1416
+ } else if (m_rem >= 8 && n_rem >= 4) {
1417
+ mc = 8;
1418
+ nc = 4;
1419
+ gemm<8,4>(m0, m, n0, n);
1420
+ } else if (m_rem >= 4 && n_rem >= 4) {
1421
+ mc = 4;
1422
+ nc = 4;
1423
+ gemm<4,4>(m0, m, n0, n);
1424
+ } else if ((m_rem < 4) && (n_rem > 4)) {
1425
+ nc = 4;
1426
+ switch(m_rem) {
1427
+ case 1:
1428
+ mc = 1;
1429
+ gemm_small(m0, m, n0, n, mc, nc);
1430
+ break;
1431
+ case 2:
1432
+ mc = 2;
1433
+ gemm_small(m0, m, n0, n, mc, nc);
1434
+ break;
1435
+ case 3:
1436
+ mc = 3;
1437
+ gemm_small(m0, m, n0, n, mc, nc);
1438
+ break;
1439
+ default:
1440
+ return;
1441
+ }
1442
+ } else if ((m_rem > 4) && (n_rem < 4)) {
1443
+ mc = 4;
1444
+ switch(n_rem) {
1445
+ case 1:
1446
+ nc = 1;
1447
+ gemm_small(m0, m, n0, n, mc, nc);
1448
+ break;
1449
+ case 2:
1450
+ nc = 2;
1451
+ gemm_small(m0, m, n0, n, mc, nc);
1452
+ break;
1453
+ case 3:
1454
+ nc = 3;
1455
+ gemm_small(m0, m, n0, n, mc, nc);
1456
+ break;
1457
+ default:
1458
+ return;
1459
+ }
1460
+ } else {
1461
+ switch((m_rem << 4) | n_rem) {
1462
+ case 0x43:
1463
+ mc = 4;
1464
+ nc = 3;
1465
+ gemm_small(m0, m, n0, n, mc, nc);
1466
+ break;
1467
+ case 0x42:
1468
+ mc = 4;
1469
+ nc = 2;
1470
+ gemm_small(m0, m, n0, n, mc, nc);
1471
+ break;
1472
+ case 0x41:
1473
+ mc = 4;
1474
+ nc = 1;
1475
+ gemm_small(m0, m, n0, n, mc, nc);
1476
+ break;
1477
+ case 0x34:
1478
+ mc = 3;
1479
+ nc = 4;
1480
+ gemm_small(m0, m, n0, n, mc, nc);
1481
+ break;
1482
+ case 0x33:
1483
+ mc = 3;
1484
+ nc = 3;
1485
+ gemm_small(m0, m, n0, n, mc, nc);
1486
+ break;
1487
+ case 0x32:
1488
+ mc = 3;
1489
+ nc = 2;
1490
+ gemm_small(m0, m, n0, n, mc, nc);
1491
+ break;
1492
+ case 0x31:
1493
+ mc = 3;
1494
+ nc = 1;
1495
+ gemm_small(m0, m, n0, n, mc, nc);
1496
+ break;
1497
+ case 0x24:
1498
+ mc = 2;
1499
+ nc = 4;
1500
+ gemm_small(m0, m, n0, n, mc, nc);
1501
+ break;
1502
+ case 0x23:
1503
+ mc = 2;
1504
+ nc = 3;
1505
+ gemm_small(m0, m, n0, n, mc, nc);
1506
+ break;
1507
+ case 0x22:
1508
+ mc = 2;
1509
+ nc = 2;
1510
+ gemm_small(m0, m, n0, n, mc, nc);
1511
+ break;
1512
+ case 0x21:
1513
+ mc = 2;
1514
+ nc = 1;
1515
+ gemm_small(m0, m, n0, n, mc, nc);
1516
+ break;
1517
+ case 0x14:
1518
+ mc = 1;
1519
+ nc = 4;
1520
+ gemm_small(m0, m, n0, n, mc, nc);
1521
+ break;
1522
+ case 0x13:
1523
+ mc = 1;
1524
+ nc = 3;
1525
+ gemm_small(m0, m, n0, n, mc, nc);
1526
+ break;
1527
+ case 0x12:
1528
+ mc = 1;
1529
+ nc = 2;
1530
+ gemm_small(m0, m, n0, n, mc, nc);
1531
+ break;
1532
+ case 0x11:
1533
+ mc = 1;
1534
+ nc = 1;
1535
+ gemm_small(m0, m, n0, n, mc, nc);
1536
+ break;
1537
+ default:
1538
+ return;
1539
+ }
1540
+ }
1541
+ mp = m0 + (m - m0) / mc * mc;
1542
+ np = n0 + (n - n0) / nc * nc;
1543
+ mnpack(mp, m, n0, np);
1544
+ mnpack(m0, m, np, n);
1545
+ }
1546
+
1547
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
1548
+ int64_t ytiles = (m - m0) / RM;
1549
+ int64_t xtiles = (n - n0) / RN;
1550
+ int64_t tiles = xtiles * ytiles;
1551
+ int64_t duty = (tiles + nth - 1) / nth;
1552
+ int64_t start = duty * ith;
1553
+ int64_t end = start + duty;
1554
+ if (end > tiles)
1555
+ end = tiles;
1556
+ for (int64_t job = start; job < end; ++job) {
1557
+ int64_t ii = m0 + job / xtiles * RM;
1558
+ int64_t jj = n0 + job % xtiles * RN;
1559
+ vec_t vec_C[4];
1560
+ acc_t acc_0;
1561
+ __builtin_mma_xxsetaccz(&acc_0);
1562
+ vec_t vec_A[4], vec_B[4];
1563
+ for (int l=0; l<k; l+=4) {
1564
+ if (RN >= 4 && RM == 1) {
1565
+ float* a = const_cast<float*>(A+(ii)*lda+l);
1566
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
1567
+ vec_A[0] = (vec_t)vec_xl(0,a);
1568
+ vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
1569
+ vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
1570
+ vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
1571
+ } else {
1572
+ READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
1573
+ READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
1574
+ }
1575
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
1576
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
1577
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
1578
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
1579
+ }
1580
+ __builtin_mma_disassemble_acc(vec_C, &acc_0);
1581
+ for (int I = 0; I < RM; I++) {
1582
+ for (int J = 0; J < RN; J++) {
1583
+ *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
1584
+ }
1585
+ }
1586
+ }
1587
+ }
1588
+
1589
+ template <int RM, int RN>
1590
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1591
+ int64_t ytiles = (m - m0) / RM;
1592
+ int64_t xtiles = (n - n0) / RN;
1593
+ int64_t tiles = xtiles * ytiles;
1594
+ int64_t duty = (tiles + nth - 1) / nth;
1595
+ int64_t start = duty * ith;
1596
+ int64_t end = start + duty;
1597
+ if (RM == 4 && RN == 4) {
1598
+ kernel = &tinyBLAS_PPC::KERNEL_4x4;
1599
+ } else if (RM == 4 && RN == 8) {
1600
+ kernel = &tinyBLAS_PPC::KERNEL_4x8;
1601
+ } else if (RM == 8 && RN == 4) {
1602
+ kernel = &tinyBLAS_PPC::KERNEL_8x4;
1603
+ } else if (RM == 8 && RN == 8) {
1604
+ kernel = &tinyBLAS_PPC::KERNEL_8x8;
1605
+ }
1606
+ if (end > tiles)
1607
+ end = tiles;
1608
+ for (int64_t job = start; job < end; ++job) {
1609
+ int64_t ii = m0 + job / xtiles * RM;
1610
+ int64_t jj = n0 + job % xtiles * RN;
1611
+ (this->*kernel)(ii, jj);
1612
+ }
1613
+ }
1614
+
1615
+ const TA *const A;
1616
+ const TB *const B;
1617
+ TC *C;
1618
+ TA *At;
1619
+ TB *Bt;
1620
+ const int64_t k;
1621
+ const int64_t lda;
1622
+ const int64_t ldb;
1623
+ const int64_t ldc;
1624
+ const int ith;
1625
+ const int nth;
1626
+ };
1627
+ #endif
988
1628
  } // namespace
989
1629
 
990
1630
  /**
@@ -1073,6 +1713,16 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1073
1713
  ith, nth};
1074
1714
  tb.matmul(m, n);
1075
1715
  return true;
1716
+ #elif defined(__MMA__)
1717
+ if (k % 8)
1718
+ return false;
1719
+ tinyBLAS_PPC<float, float, float> tb{
1720
+ k, (const float *)A, lda,
1721
+ (const float *)B, ldb,
1722
+ (float *)C, ldc,
1723
+ ith, nth};
1724
+ tb.matmul(m, n);
1725
+ return true;
1076
1726
  #else
1077
1727
  return false;
1078
1728
  #endif
@@ -1182,6 +1832,22 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
1182
1832
  #endif
1183
1833
  }
1184
1834
 
1835
+ case GGML_TYPE_Q5_0: {
1836
+ if (Btype != GGML_TYPE_Q8_0)
1837
+ return false;
1838
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
1839
+ tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
1840
+ k, (const block_q5_0 *)A, lda,
1841
+ (const block_q8_0 *)B, ldb,
1842
+ (float *)C, ldc,
1843
+ ith, nth};
1844
+ tb.matmul(m, n);
1845
+ return true;
1846
+ #else
1847
+ return false;
1848
+ #endif
1849
+ }
1850
+
1185
1851
  case GGML_TYPE_IQ4_NL: {
1186
1852
  if (Btype != GGML_TYPE_Q8_0)
1187
1853
  return false;