cui-llama.rn 1.3.5 → 1.4.0

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (80)
  1. package/README.md +22 -1
  2. package/android/src/main/CMakeLists.txt +25 -20
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +31 -9
  4. package/android/src/main/java/com/rnllama/RNLlama.java +98 -0
  5. package/android/src/main/jni-utils.h +94 -0
  6. package/android/src/main/jni.cpp +108 -37
  7. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +15 -0
  8. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +15 -0
  9. package/cpp/common.cpp +1982 -1965
  10. package/cpp/common.h +665 -657
  11. package/cpp/ggml-backend-reg.cpp +5 -0
  12. package/cpp/ggml-backend.cpp +5 -2
  13. package/cpp/ggml-cpp.h +1 -0
  14. package/cpp/ggml-cpu-aarch64.cpp +6 -1
  15. package/cpp/ggml-cpu-quants.c +5 -1
  16. package/cpp/ggml-cpu.c +14122 -14122
  17. package/cpp/ggml-cpu.cpp +627 -627
  18. package/cpp/ggml-impl.h +11 -16
  19. package/cpp/ggml-metal-impl.h +288 -0
  20. package/cpp/ggml-metal.m +2 -2
  21. package/cpp/ggml-opt.cpp +854 -0
  22. package/cpp/ggml-opt.h +216 -0
  23. package/cpp/ggml.c +0 -1276
  24. package/cpp/ggml.h +0 -140
  25. package/cpp/gguf.cpp +1325 -0
  26. package/cpp/gguf.h +202 -0
  27. package/cpp/llama-adapter.cpp +346 -0
  28. package/cpp/llama-adapter.h +73 -0
  29. package/cpp/llama-arch.cpp +1434 -0
  30. package/cpp/llama-arch.h +395 -0
  31. package/cpp/llama-batch.cpp +368 -0
  32. package/cpp/llama-batch.h +88 -0
  33. package/cpp/llama-chat.cpp +567 -0
  34. package/cpp/llama-chat.h +51 -0
  35. package/cpp/llama-context.cpp +1771 -0
  36. package/cpp/llama-context.h +128 -0
  37. package/cpp/llama-cparams.cpp +1 -0
  38. package/cpp/llama-cparams.h +37 -0
  39. package/cpp/llama-cpp.h +30 -0
  40. package/cpp/llama-grammar.cpp +1 -0
  41. package/cpp/llama-grammar.h +3 -1
  42. package/cpp/llama-hparams.cpp +71 -0
  43. package/cpp/llama-hparams.h +140 -0
  44. package/cpp/llama-impl.cpp +167 -0
  45. package/cpp/llama-impl.h +16 -136
  46. package/cpp/llama-kv-cache.cpp +718 -0
  47. package/cpp/llama-kv-cache.h +218 -0
  48. package/cpp/llama-mmap.cpp +589 -0
  49. package/cpp/llama-mmap.h +67 -0
  50. package/cpp/llama-model-loader.cpp +1011 -0
  51. package/cpp/llama-model-loader.h +158 -0
  52. package/cpp/llama-model.cpp +2202 -0
  53. package/cpp/llama-model.h +391 -0
  54. package/cpp/llama-sampling.cpp +117 -4
  55. package/cpp/llama-vocab.cpp +21 -28
  56. package/cpp/llama-vocab.h +13 -1
  57. package/cpp/llama.cpp +12547 -23528
  58. package/cpp/llama.h +31 -6
  59. package/cpp/rn-llama.hpp +90 -87
  60. package/cpp/sgemm.cpp +776 -70
  61. package/cpp/sgemm.h +14 -14
  62. package/cpp/unicode.cpp +6 -0
  63. package/ios/RNLlama.mm +47 -0
  64. package/ios/RNLlamaContext.h +3 -1
  65. package/ios/RNLlamaContext.mm +71 -14
  66. package/jest/mock.js +15 -3
  67. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  68. package/lib/commonjs/index.js +33 -37
  69. package/lib/commonjs/index.js.map +1 -1
  70. package/lib/module/NativeRNLlama.js.map +1 -1
  71. package/lib/module/index.js +31 -35
  72. package/lib/module/index.js.map +1 -1
  73. package/lib/typescript/NativeRNLlama.d.ts +26 -6
  74. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  75. package/lib/typescript/index.d.ts +21 -36
  76. package/lib/typescript/index.d.ts.map +1 -1
  77. package/llama-rn.podspec +4 -18
  78. package/package.json +2 -3
  79. package/src/NativeRNLlama.ts +32 -13
  80. package/src/index.ts +52 -47
package/cpp/sgemm.cpp CHANGED
@@ -54,6 +54,7 @@
  #include "ggml-quants.h"

  #include <atomic>
+ #include <array>

  #ifdef _MSC_VER
  #define NOINLINE __declspec(noinline)
@@ -1000,8 +1001,10 @@ class tinyBLAS_Q0_AVX {

  inline __m256 updot(__m256i u, __m256i s) {
  __m256i res;
- #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+ #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
  res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
+ #elif defined(__AVXVNNI__)
+ res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
  #else
  res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
  #endif
@@ -1049,6 +1052,704 @@ class tinyBLAS_Q0_AVX {
  } \
  } \

+ template <typename TA, typename TB, typename TC>
+ class tinyBLAS_Q0_PPC {
+ public:
+ tinyBLAS_Q0_PPC(int64_t k,
+ const TA *A, int64_t lda,
+ const TB *B, int64_t ldb,
+ TC *C, int64_t ldc,
+ int ith, int nth)
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+ }
+
+ void matmul(int64_t m, int64_t n) {
+ mnpack(0, m, 0, n);
+ }
+
+ private:
+
+ template<int RM, int RN>
+ inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
+ for (int I = 0; I < RM; I++) {
+ for (int J = 0; J < RN; J++) {
+ *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
+ }
+ }
+ }
+
+ template<int size>
+ inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
+ vector signed int vec_C[4];
+ vector float CA[4] = {0};
+ vector float res[4] = {0};
+ __builtin_mma_disassemble_acc(vec_C, ACC);
+ for (int i = 0; i < 4; i++) {
+ CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
+ res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+ fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
+ }
+ }
+
+ template<typename VA, typename VB>
+ void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+ int64_t i, j;
+ TA *aoffset = NULL;
+ VA *vecOffset = NULL;
+ TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+ TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+ __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
+ VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
+ VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
+ VB t1, t2, t3, t4, t5, t6, t7, t8;
+ vector unsigned char xor_vector;
+ uint8_t flip_vec = 0x80;
+ xor_vector = vec_splats(flip_vec);
+ vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+ vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+ vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+ vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+
+ aoffset = const_cast<TA*>(a);
+ vecOffset = vec;
+ j = (rows >> 3);
+ if (j > 0) {
+ do {
+ aoffset1 = aoffset;
+ aoffset2 = aoffset1 + lda;
+ aoffset3 = aoffset2 + lda;
+ aoffset4 = aoffset3 + lda;
+ aoffset5 = aoffset4 + lda;
+ aoffset6 = aoffset5 + lda;
+ aoffset7 = aoffset6 + lda;
+ aoffset8 = aoffset7 + lda;
+ aoffset += 8 * lda;
+
+ i = (cols >> 3);
+ if (i > 0) {
+ do {
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
+ C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
+ C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
+ C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
+ C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
+
+ __builtin_vsx_disassemble_pair(c1, &C1);
+ __builtin_vsx_disassemble_pair(c2, &C2);
+ __builtin_vsx_disassemble_pair(c3, &C3);
+ __builtin_vsx_disassemble_pair(c4, &C4);
+ __builtin_vsx_disassemble_pair(c5, &C5);
+ __builtin_vsx_disassemble_pair(c6, &C6);
+ __builtin_vsx_disassemble_pair(c7, &C7);
+ __builtin_vsx_disassemble_pair(c8, &C8);
+
+ t1 = vec_perm(c1[0], c2[0], swiz1);
+ t2 = vec_perm(c1[0], c2[0], swiz2);
+ t3 = vec_perm(c3[0], c4[0], swiz1);
+ t4 = vec_perm(c3[0], c4[0], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ if (flip == true) {
+ t5 = vec_xor(t5, xor_vector);
+ t6 = vec_xor(t6, xor_vector);
+ t7 = vec_xor(t7, xor_vector);
+ t8 = vec_xor(t8, xor_vector);
+ }
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset+16);
+ vec_xst(t7, 0, vecOffset+32);
+ vec_xst(t8, 0, vecOffset+48);
+
+ t1 = vec_perm(c1[1], c2[1], swiz1);
+ t2 = vec_perm(c1[1], c2[1], swiz2);
+ t3 = vec_perm(c3[1], c4[1], swiz1);
+ t4 = vec_perm(c3[1], c4[1], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ if (flip == true) {
+ t5 = vec_xor(t5, xor_vector);
+ t6 = vec_xor(t6, xor_vector);
+ t7 = vec_xor(t7, xor_vector);
+ t8 = vec_xor(t8, xor_vector);
+ }
+ vec_xst(t5, 0, vecOffset+64);
+ vec_xst(t6, 0, vecOffset+80);
+ vec_xst(t7, 0, vecOffset+96);
+ vec_xst(t8, 0, vecOffset+112);
+
+ t1 = vec_perm(c5[0], c6[0], swiz1);
+ t2 = vec_perm(c5[0], c6[0], swiz2);
+ t3 = vec_perm(c7[0], c8[0], swiz1);
+ t4 = vec_perm(c7[0], c8[0], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ if (flip == true) {
+ t5 = vec_xor(t5, xor_vector);
+ t6 = vec_xor(t6, xor_vector);
+ t7 = vec_xor(t7, xor_vector);
+ t8 = vec_xor(t8, xor_vector);
+ }
+ vec_xst(t5, 0, vecOffset+128);
+ vec_xst(t6, 0, vecOffset+144);
+ vec_xst(t7, 0, vecOffset+160);
+ vec_xst(t8, 0, vecOffset+176);
+
+ t1 = vec_perm(c5[1], c6[1], swiz1);
+ t2 = vec_perm(c5[1], c6[1], swiz2);
+ t3 = vec_perm(c7[1], c8[1], swiz1);
+ t4 = vec_perm(c7[1], c8[1], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ if (flip == true) {
+ t5 = vec_xor(t5, xor_vector);
+ t6 = vec_xor(t6, xor_vector);
+ t7 = vec_xor(t7, xor_vector);
+ t8 = vec_xor(t8, xor_vector);
+ }
+ vec_xst(t5, 0, vecOffset+192);
+ vec_xst(t6, 0, vecOffset+208);
+ vec_xst(t7, 0, vecOffset+224);
+ vec_xst(t8, 0, vecOffset+240);
+
+ aoffset1 += lda;
+ aoffset2 += lda;
+ aoffset3 += lda;
+ aoffset4 += lda;
+ aoffset5 += lda;
+ aoffset6 += lda;
+ aoffset7 += lda;
+ aoffset8 += lda;
+ vecOffset += 256;
+ i--;
+ } while(i > 0);
+ }
+ j--;
+ } while(j > 0);
+ }
+
+ if (rows & 4) {
+ aoffset1 = aoffset;
+ aoffset2 = aoffset1 + lda;
+ aoffset3 = aoffset2 + lda;
+ aoffset4 = aoffset3 + lda;
+ aoffset += 4 * lda;
+
+ i = (cols >> 3);
+ if (i > 0) {
+ do {
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
+
+ __builtin_vsx_disassemble_pair(c1, &C1);
+ __builtin_vsx_disassemble_pair(c2, &C2);
+ __builtin_vsx_disassemble_pair(c3, &C3);
+ __builtin_vsx_disassemble_pair(c4, &C4);
+
+ t1 = vec_perm(c1[0], c2[0], swiz1);
+ t2 = vec_perm(c1[0], c2[0], swiz2);
+ t3 = vec_perm(c3[0], c4[0], swiz1);
+ t4 = vec_perm(c3[0], c4[0], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ if (flip == true) {
+ t5 = vec_xor(t5, xor_vector);
+ t6 = vec_xor(t6, xor_vector);
+ t7 = vec_xor(t7, xor_vector);
+ t8 = vec_xor(t8, xor_vector);
+ }
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset+16);
+ vec_xst(t7, 0, vecOffset+32);
+ vec_xst(t8, 0, vecOffset+48);
+
+ t1 = vec_perm(c1[1], c2[1], swiz1);
+ t2 = vec_perm(c1[1], c2[1], swiz2);
+ t3 = vec_perm(c3[1], c4[1], swiz1);
+ t4 = vec_perm(c3[1], c4[1], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ if (flip == true) {
+ t5 = vec_xor(t5, xor_vector);
+ t6 = vec_xor(t6, xor_vector);
+ t7 = vec_xor(t7, xor_vector);
+ t8 = vec_xor(t8, xor_vector);
+ }
+ vec_xst(t5, 0, vecOffset+64);
+ vec_xst(t6, 0, vecOffset+80);
+ vec_xst(t7, 0, vecOffset+96);
+ vec_xst(t8, 0, vecOffset+112);
+
+ aoffset1 += lda;
+ aoffset2 += lda;
+ aoffset3 += lda;
+ aoffset4 += lda;
+ vecOffset += 128;
+ i--;
+ } while(i > 0);
+ }
+ }
+ if (rows & 3) {
+ aoffset1 = aoffset;
+ aoffset2 = aoffset1 + lda;
+ aoffset3 = aoffset2 + lda;
+ i = (cols >> 3);
+ if (i > 0) {
+ do {
+ switch(rows) {
+ case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+ __builtin_vsx_disassemble_pair(c3, &C3);
+ case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+ __builtin_vsx_disassemble_pair(c2, &C2);
+ case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+ __builtin_vsx_disassemble_pair(c1, &C1);
+ break;
+ }
+ t1 = vec_perm(c1[0], c2[0], swiz1);
+ t2 = vec_perm(c1[0], c2[0], swiz2);
+ t3 = vec_perm(c3[0], c4[0], swiz1);
+ t4 = vec_perm(c3[0], c4[0], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ if (flip == true) {
+ t5 = vec_xor(t5, xor_vector);
+ t6 = vec_xor(t6, xor_vector);
+ t7 = vec_xor(t7, xor_vector);
+ t8 = vec_xor(t8, xor_vector);
+ }
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset+16);
+ vec_xst(t7, 0, vecOffset+32);
+ vec_xst(t8, 0, vecOffset+48);
+
+ t1 = vec_perm(c1[1], c2[1], swiz1);
+ t2 = vec_perm(c1[1], c2[1], swiz2);
+ t3 = vec_perm(c3[1], c4[1], swiz1);
+ t4 = vec_perm(c3[1], c4[1], swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ if (flip == true) {
+ t5 = vec_xor(t5, xor_vector);
+ t6 = vec_xor(t6, xor_vector);
+ t7 = vec_xor(t7, xor_vector);
+ t8 = vec_xor(t8, xor_vector);
+ }
+ vec_xst(t5, 0, vecOffset+64);
+ vec_xst(t6, 0, vecOffset+80);
+ vec_xst(t7, 0, vecOffset+96);
+ vec_xst(t8, 0, vecOffset+112);
+
+ aoffset1 += lda;
+ aoffset2 += lda;
+ aoffset3 += lda;
+ vecOffset += 128;
+ i--;
+ } while(i > 0);
+ }
+ }
+ }
+
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t mc, nc, mp, np;
+ int m_rem = MIN(m - m0, 8);
+ int n_rem = MIN(n - n0, 8);
+ // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance
+ // issues. After resolving them, below code will be enabled.
+ /*if (m_rem >= 16 && n_rem >= 8) {
+ mc = 16;
+ nc = 8;
+ gemm<16,8>(m0, m, n0, n);
+ } else if(m_rem >= 8 && n_rem >= 16) {
+ mc = 8;
+ nc = 16;
+ gemm<8,16>(m0, m, n0, n);
+ }*/
+ if (m_rem >= 8 && n_rem >= 8) {
+ mc = 8;
+ nc = 8;
+ gemm<8,8>(m0, m, n0, n);
+ } else if (m_rem >= 4 && n_rem >= 8) {
+ mc = 4;
+ nc = 8;
+ gemm<4,8>(m0, m, n0, n);
+ } else if (m_rem >= 8 && n_rem >= 4) {
+ mc = 8;
+ nc = 4;
+ gemm<8,4>(m0, m, n0, n);
+ } else if (m_rem >= 4 && n_rem >= 4) {
+ mc = 4;
+ nc = 4;
+ gemm_small<4, 4>(m0, m, n0, n);
+ } else if ((m_rem < 4) && (n_rem > 4)) {
+ nc = 4;
+ switch(m_rem) {
+ case 1:
+ mc = 1;
+ gemm_small<1, 4>(m0, m, n0, n);
+ break;
+ case 2:
+ mc = 2;
+ gemm_small<2, 4>(m0, m, n0, n);
+ break;
+ case 3:
+ mc = 3;
+ gemm_small<3, 4>(m0, m, n0, n);
+ break;
+ default:
+ return;
+ }
+ } else if ((m_rem > 4) && (n_rem < 4)) {
+ mc = 4;
+ switch(n_rem) {
+ case 1:
+ nc = 1;
+ gemm_small<4, 1>(m0, m, n0, n);
+ break;
+ case 2:
+ nc = 2;
+ gemm_small<4, 2>(m0, m, n0, n);
+ break;
+ case 3:
+ nc = 3;
+ gemm_small<4, 3>(m0, m, n0, n);
+ break;
+ default:
+ return;
+ }
+ } else {
+ switch((m_rem << 4) | n_rem) {
+ case 0x43:
+ mc = 4;
+ nc = 3;
+ gemm_small<4, 3>(m0, m, n0, n);
+ break;
+ case 0x42:
+ mc = 4;
+ nc = 2;
+ gemm_small<4, 2>(m0, m, n0, n);
+ break;
+ case 0x41:
+ mc = 4;
+ nc = 1;
+ gemm_small<4, 1>(m0, m, n0, n);
+ break;
+ case 0x34:
+ mc = 3;
+ nc = 4;
+ gemm_small<3, 4>(m0, m, n0, n);
+ break;
+ case 0x33:
+ mc = 3;
+ nc = 3;
+ gemm_small<3, 3>(m0, m, n0, n);
+ break;
+ case 0x32:
+ mc = 3;
+ nc = 2;
+ gemm_small<3, 2>(m0, m, n0, n);
+ break;
+ case 0x31:
+ mc = 3;
+ nc = 1;
+ gemm_small<3, 1>(m0, m, n0, n);
+ break;
+ case 0x24:
+ mc = 2;
+ nc = 4;
+ gemm_small<2, 4>(m0, m, n0, n);
+ break;
+ case 0x23:
+ mc = 2;
+ nc = 3;
+ gemm_small<2, 3>(m0, m, n0, n);
+ break;
+ case 0x22:
+ mc = 2;
+ nc = 2;
+ gemm_small<2, 2>(m0, m, n0, n);
+ break;
+ case 0x21:
+ mc = 2;
+ nc = 1;
+ gemm_small<2, 1>(m0, m, n0, n);
+ break;
+ case 0x14:
+ mc = 1;
+ nc = 4;
+ gemm_small<1, 4>(m0, m, n0, n);
+ break;
+ case 0x13:
+ mc = 1;
+ nc = 3;
+ gemm_small<1, 3>(m0, m, n0, n);
+ break;
+ case 0x12:
+ mc = 1;
+ nc = 2;
+ gemm_small<1, 2>(m0, m, n0, n);
+ break;
+ case 0x11:
+ mc = 1;
+ nc = 1;
+ gemm_small<1, 1>(m0, m, n0, n);
+ break;
+ default:
+ return;
+ }
+ }
+ mp = m0 + (m - m0) / mc * mc;
+ np = n0 + (n - n0) / nc * nc;
+ mnpack(mp, m, n0, np);
+ mnpack(m0, m, np, n);
+ }
+
+ void KERNEL_4x8(int64_t ii, int64_t jj) {
+ vec_t vec_A[8], vec_B[16] = {0};
+ acc_t acc_0, acc_1;
+ std::array<int, 4> comparray;
+ vector float fin_res[8] = {0};
+ vector float vs[8] = {0};
+ for (int l = 0; l < k; l++) {
+ __builtin_mma_xxsetaccz(&acc_0);
+ __builtin_mma_xxsetaccz(&acc_1);
+ packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+ packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+ for(int x = 0; x < 8; x++) {
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+ __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
+ }
+ for (int I = 0; I<4; I++) {
+ for (int J = 0; J<4; J++) {
+ *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+ *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+ }
+ }
+ auto aoffset = A+(ii*lda)+l;
+ for (int i = 0; i < 4; i++) {
+ comparray[i] = 0;
+ int ca = 0;
+ const int8_t *at = aoffset->qs;
+ for (int j = 0; j < 32; j++)
+ ca += (int)*at++;
+ comparray[i] = ca;
+ aoffset += lda;
+ }
+ compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
+ compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
+ }
+ save_res<4, 4>(ii, jj, 0, fin_res);
+ save_res<4, 4>(ii, jj+4, 4, fin_res);
+ }
+
+ void KERNEL_8x4(int64_t ii, int64_t jj) {
+ vec_t vec_A[16], vec_B[8] = {0};
+ acc_t acc_0, acc_1;
+ std::array<int, 8> comparray;
+ vector float fin_res[8] = {0};
+ vector float vs[8] = {0};
+ for (int l = 0; l < k; l++) {
+ __builtin_mma_xxsetaccz(&acc_0);
+ __builtin_mma_xxsetaccz(&acc_1);
+ packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+ packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
+ for(int x = 0; x < 8; x++) {
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+ __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
+ }
+ for (int I = 0; I<8; I++) {
+ for (int J = 0; J<4; J++) {
+ *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+ }
+ }
+ auto aoffset = A+(ii*lda)+l;
+ for (int i = 0; i < 8; i++) {
+ comparray[i] = 0;
+ int ca = 0;
+ const int8_t *at = aoffset->qs;
+ for (int j = 0; j < 32; j++)
+ ca += (int)*at++;
+ comparray[i] = ca;
+ aoffset += lda;
+ }
+ compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
+ compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
+ }
+ save_res<4, 4>(ii, jj, 0, fin_res);
+ save_res<4, 4>(ii+4, jj, 4, fin_res);
+ }
+
+ void KERNEL_8x8(int64_t ii, int64_t jj) {
+ vec_t vec_A[16], vec_B[16] = {0};
+ acc_t acc_0, acc_1, acc_2, acc_3;
+ std::array<int, 8> comparray;
+ vector float fin_res[16] = {0};
+ vector float vs[16] = {0};
+ for (int l = 0; l < k; l++) {
+ __builtin_mma_xxsetaccz(&acc_0);
+ __builtin_mma_xxsetaccz(&acc_1);
+ __builtin_mma_xxsetaccz(&acc_2);
+ __builtin_mma_xxsetaccz(&acc_3);
+ packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+ packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+ for(int x = 0; x < 8; x++) {
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+ __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
+ __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
+ __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
+ }
+ for (int I = 0; I<8; I++) {
+ for (int J = 0; J<4; J++) {
+ *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+ *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+ }
+ }
+ auto aoffset = A+(ii*lda)+l;
+ for (int i = 0; i < 8; i++) {
+ comparray[i] = 0;
+ int ca = 0;
+ const int8_t *at = aoffset->qs;
+ for (int j = 0; j < 32; j++)
+ ca += (int)*at++;
+ comparray[i] = ca;
+ aoffset += lda;
+ }
+ compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
+ compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
+ compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
+ compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
+ }
+ save_res<4, 4>(ii, jj, 0, fin_res);
+ save_res<4, 4>(ii+4, jj, 4, fin_res);
+ save_res<4, 4>(ii, jj+4, 8, fin_res);
+ save_res<4, 4>(ii+4, jj+4, 12, fin_res);
+ }
+
+ template<int RM, int RN>
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ vec_t vec_A[8], vec_B[8] = {0};
+ vector signed int vec_C[4];
+ acc_t acc_0;
+
+ if (end > tiles)
+ end = tiles;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
+ std::array<int, RM> comparray;
+ vector float res[4] = {0};
+ vector float fin_res[4] = {0};
+ vector float vs[4] = {0};
+ vector float CA[4] = {0};
+ __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
+ __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
+ for (int l = 0; l < k; l++) {
+ __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
+ __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
+ __builtin_mma_xxsetaccz(&acc_0);
+ packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+ packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
+ for(int x = 0; x < 8; x+=4) {
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
+ }
+ for (int I = 0; I<RM; I++) {
+ for (int J = 0; J<RN; J++) {
+ *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+ }
+ }
+ __builtin_mma_disassemble_acc(vec_C, &acc_0);
+ auto aoffset = A+(ii*lda)+l;
+ for (int i = 0; i < RM; i++) {
+ comparray[i] = 0;
+ int ca = 0;
+ const int8_t *at = aoffset->qs;
+ for (int j = 0; j < 32; j++)
+ ca += (int)*at++;
+ comparray[i] = ca;
+ aoffset += lda;
+ }
+
+ for (int i = 0; i < RM; i++) {
+ CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
+ res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+ fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
+ }
+ }
+ save_res<RM, RN>(ii, jj, 0, fin_res);
+ }
+ }
+
+ template<int RM, int RN>
+ inline void kernel(int64_t ii, int64_t jj) {
+ if constexpr(RM == 4 && RN == 8) {
+ KERNEL_4x8(ii,jj);
+ } else if constexpr(RM == 8 && RN == 4) {
+ KERNEL_8x4(ii,jj);
+ } else if constexpr(RM == 8 && RN == 8) {
+ KERNEL_8x8(ii,jj);
+ } else {
+ static_assert(false, "RN/RM values not supported");
+ }
+ }
+
+ template <int RM, int RN>
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
+ kernel<RM, RN>(ii, jj);
+ }
+ }
+
+ const TA *const A;
+ const TB *const B;
+ TC *C;
+ TA *At;
+ TB *Bt;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
+ const int ith;
+ const int nth;
+ };
+
  template <typename TA, typename TB, typename TC>
  class tinyBLAS_PPC {
  public:
@@ -1068,13 +1769,17 @@ class tinyBLAS_PPC {

  void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);

- void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
+ template<typename VA>
+ void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
  int64_t i, j;
- float *aoffset = NULL, *boffset = NULL;
- float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
- float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
-
- aoffset = const_cast<float*>(a);
+ TA *aoffset = NULL, *boffset = NULL;
+ TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+ TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+ __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
+ VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+ VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+ VA t1, t2, t3, t4, t5, t6, t7, t8;
+ aoffset = const_cast<TA*>(a);
  boffset = vec;
  j = (rows >> 3);
  if (j > 0) {
@@ -1090,9 +1795,6 @@ class tinyBLAS_PPC {
  aoffset += 8 * lda;
  i = (cols >> 3);
  if (i > 0) {
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
- vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
  do {
  C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
  C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
@@ -1172,21 +1874,19 @@ class tinyBLAS_PPC {
  } while(i > 0);
  }
  if (cols & 4) {
- vector float c1, c2, c3, c4, c5, c6, c7, c8;
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
- c1 = vec_xl(0, aoffset1);
- c2 = vec_xl(0, aoffset2);
- c3 = vec_xl(0, aoffset3);
- c4 = vec_xl(0, aoffset4);
- c5 = vec_xl(0, aoffset5);
- c6 = vec_xl(0, aoffset6);
- c7 = vec_xl(0, aoffset7);
- c8 = vec_xl(0, aoffset8);
-
- t1 = vec_mergeh(c1, c2);
- t2 = vec_mergeh(c3, c4);
- t3 = vec_mergeh(c5, c6);
- t4 = vec_mergeh(c7, c8);
+ c1[0] = vec_xl(0, aoffset1);
+ c2[0] = vec_xl(0, aoffset2);
+ c3[0] = vec_xl(0, aoffset3);
+ c4[0] = vec_xl(0, aoffset4);
+ c5[0] = vec_xl(0, aoffset5);
+ c6[0] = vec_xl(0, aoffset6);
+ c7[0] = vec_xl(0, aoffset7);
+ c8[0] = vec_xl(0, aoffset8);
+
+ t1 = vec_mergeh(c1[0], c2[0]);
+ t2 = vec_mergeh(c3[0], c4[0]);
+ t3 = vec_mergeh(c5[0], c6[0]);
+ t4 = vec_mergeh(c7[0], c8[0]);
  t5 = vec_xxpermdi(t1, t2, 0);
  t6 = vec_xxpermdi(t3, t4, 0);
  t7 = vec_xxpermdi(t1, t2, 3);
@@ -1196,10 +1896,10 @@ class tinyBLAS_PPC {
  vec_xst(t7, 0, boffset+8);
  vec_xst(t8, 0, boffset+12);

- t1 = vec_mergel(c1, c2);
- t2 = vec_mergel(c3, c4);
- t3 = vec_mergel(c5, c6);
- t4 = vec_mergel(c7, c8);
+ t1 = vec_mergel(c1[0], c2[0]);
+ t2 = vec_mergel(c3[0], c4[0]);
+ t3 = vec_mergel(c5[0], c6[0]);
+ t4 = vec_mergel(c7[0], c8[0]);
  t5 = vec_xxpermdi(t1, t2, 0);
  t6 = vec_xxpermdi(t3, t4, 0);
  t7 = vec_xxpermdi(t1, t2, 3);
@@ -1221,9 +1921,6 @@ class tinyBLAS_PPC {
  aoffset += 4 * lda;
  i = (cols >> 3);
  if (i > 0) {
- __vector_pair C1, C2, C3, C4;
- vector float c1[2], c2[2], c3[2], c4[2];
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
  do {
  C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
  C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
@@ -1270,22 +1967,20 @@ class tinyBLAS_PPC {
  }

  if (cols & 4) {
- vector float c1, c2, c3, c4;
- vector float t1, t2, t3, t4;
- c1 = vec_xl(0, aoffset1);
- c2 = vec_xl(0, aoffset2);
- c3 = vec_xl(0, aoffset3);
- c4 = vec_xl(0, aoffset4);
-
- t1 = vec_mergeh(c1, c2);
- t2 = vec_mergeh(c3, c4);
+ c1[0] = vec_xl(0, aoffset1);
+ c2[0] = vec_xl(0, aoffset2);
+ c3[0] = vec_xl(0, aoffset3);
+ c4[0] = vec_xl(0, aoffset4);
+
+ t1 = vec_mergeh(c1[0], c2[0]);
+ t2 = vec_mergeh(c3[0], c4[0]);
  t3 = vec_xxpermdi(t1, t2, 0);
  t4 = vec_xxpermdi(t1, t2, 3);
  vec_xst(t3, 0, boffset);
  vec_xst(t4, 0, boffset+4);

- t1 = vec_mergel(c1, c2);
- t2 = vec_mergel(c3, c4);
+ t1 = vec_mergel(c1[0], c2[0]);
+ t2 = vec_mergel(c3[0], c4[0]);
  t3 = vec_xxpermdi(t1, t2, 0);
  t4 = vec_xxpermdi(t1, t2, 3);
  vec_xst(t3, 0, boffset+8);
@@ -1297,21 +1992,19 @@ class tinyBLAS_PPC {
  aoffset2 = aoffset1 + lda;
  aoffset3 = aoffset2 + lda;
  if (cols & 4) {
- vector float c1, c2, c3, c4 = {0};
- vector float t1, t2, t3, t4;
- c1 = vec_xl(0, aoffset1);
- c2 = vec_xl(0, aoffset2);
- c3 = vec_xl(0, aoffset3);
-
- t1 = vec_mergeh(c1, c2);
- t2 = vec_mergeh(c3, c4);
+ c1[0] = vec_xl(0, aoffset1);
+ c2[0] = vec_xl(0, aoffset2);
+ c3[0] = vec_xl(0, aoffset3);
+
+ t1 = vec_mergeh(c1[0], c2[0]);
+ t2 = vec_mergeh(c3[0], c4[0]);
  t3 = vec_xxpermdi(t1, t2, 0);
  t4 = vec_xxpermdi(t1, t2, 3);
  vec_xst(t3, 0, boffset);
  vec_xst(t4, 0, boffset+4);

- t1 = vec_mergel(c1, c2);
- t2 = vec_mergel(c3, c4);
+ t1 = vec_mergel(c1[0], c2[0]);
+ t2 = vec_mergel(c3[0], c4[0]);
  t3 = vec_xxpermdi(t1, t2, 0);
  t4 = vec_xxpermdi(t1, t2, 3);
  vec_xst(t3, 0, boffset+8);
@@ -1319,14 +2012,13 @@ class tinyBLAS_PPC {
  }
  }
  }
-
  void KERNEL_4x4(int64_t ii, int64_t jj) {
  vec_t vec_A[4], vec_B[4], vec_C[4];
  acc_t acc_0;
  __builtin_mma_xxsetaccz(&acc_0);
  for (int l = 0; l < k; l+=4) {
- READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
- READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+ packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
@@ -1341,8 +2033,8 @@ class tinyBLAS_PPC {
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
  for (int64_t l = 0; l < k; l+=4) {
- READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
- READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
+ packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -1362,8 +2054,8 @@ class tinyBLAS_PPC {
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
  for (int64_t l = 0; l < k; l+=4) {
- READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
- READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+ packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -1385,8 +2077,8 @@ class tinyBLAS_PPC {
  __builtin_mma_xxsetaccz(&acc_2);
  __builtin_mma_xxsetaccz(&acc_3);
  for (int l = 0; l < k; l+=8) {
- READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
- READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+ packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
  for(int x = 0; x < 16; x+=2) {
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
@@ -1569,15 +2261,15 @@ class tinyBLAS_PPC {
  vec_t vec_A[4], vec_B[4];
  for (int l=0; l<k; l+=4) {
  if (RN >= 4 && RM == 1) {
- float* a = const_cast<float*>(A+(ii)*lda+l);
- READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+ TA* a = const_cast<TA*>(A+(ii)*lda+l);
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
  vec_A[0] = (vec_t)vec_xl(0,a);
- vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
- vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
- vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
+ vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
+ vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
+ vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
  } else {
- READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
- READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+ packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
  }
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -1587,7 +2279,7 @@ class tinyBLAS_PPC {
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
  for (int I = 0; I < RM; I++) {
  for (int J = 0; J < RN; J++) {
- *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
+ *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
  }
  }
  }
@@ -1810,6 +2502,20 @@ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t m, in
  params->ith, params->nth};
  tb.matmul(m, n);
  return true;
+
+ #elif defined(__MMA__)
+ if (n < 8 && n != 4)
+ return false;
+ if (m < 8 && m != 4)
+ return false;
+ tinyBLAS_Q0_PPC<block_q8_0, block_q8_0, float> tb{
+ k, (const block_q8_0 *)A, lda,
+ (const block_q8_0 *)B, ldb,
+ (float *)C, ldc,
+ params->ith, params->nth};
+ tb.matmul(m, n);
+ return true;
+
  #else
  return false;
  #endif