cui-llama.rn 1.2.6 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. package/README.md +3 -2
  2. package/android/src/main/CMakeLists.txt +26 -6
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +115 -27
  4. package/android/src/main/java/com/rnllama/RNLlama.java +40 -7
  5. package/android/src/main/jni.cpp +228 -40
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +9 -4
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +9 -4
  8. package/cpp/amx/amx.cpp +196 -0
  9. package/cpp/amx/amx.h +20 -0
  10. package/cpp/amx/common.h +101 -0
  11. package/cpp/amx/mmq.cpp +2524 -0
  12. package/cpp/amx/mmq.h +16 -0
  13. package/cpp/common.cpp +118 -251
  14. package/cpp/common.h +53 -30
  15. package/cpp/ggml-aarch64.c +46 -3395
  16. package/cpp/ggml-aarch64.h +0 -20
  17. package/cpp/ggml-alloc.c +6 -8
  18. package/cpp/ggml-backend-impl.h +33 -11
  19. package/cpp/ggml-backend-reg.cpp +423 -0
  20. package/cpp/ggml-backend.cpp +14 -676
  21. package/cpp/ggml-backend.h +46 -9
  22. package/cpp/ggml-common.h +6 -0
  23. package/cpp/ggml-cpu-aarch64.c +3823 -0
  24. package/cpp/ggml-cpu-aarch64.h +32 -0
  25. package/cpp/ggml-cpu-impl.h +14 -242
  26. package/cpp/ggml-cpu-quants.c +10835 -0
  27. package/cpp/ggml-cpu-quants.h +63 -0
  28. package/cpp/ggml-cpu.c +13971 -13720
  29. package/cpp/ggml-cpu.cpp +715 -0
  30. package/cpp/ggml-cpu.h +65 -63
  31. package/cpp/ggml-impl.h +285 -25
  32. package/cpp/ggml-metal.h +8 -8
  33. package/cpp/ggml-metal.m +1221 -728
  34. package/cpp/ggml-quants.c +189 -10681
  35. package/cpp/ggml-quants.h +78 -125
  36. package/cpp/ggml-threading.cpp +12 -0
  37. package/cpp/ggml-threading.h +12 -0
  38. package/cpp/ggml.c +688 -1460
  39. package/cpp/ggml.h +58 -244
  40. package/cpp/json-schema-to-grammar.cpp +1045 -1045
  41. package/cpp/json.hpp +24766 -24766
  42. package/cpp/llama-sampling.cpp +5 -2
  43. package/cpp/llama.cpp +409 -123
  44. package/cpp/llama.h +8 -4
  45. package/cpp/rn-llama.hpp +89 -25
  46. package/cpp/sampling.cpp +42 -3
  47. package/cpp/sampling.h +22 -1
  48. package/cpp/sgemm.cpp +608 -0
  49. package/cpp/speculative.cpp +270 -0
  50. package/cpp/speculative.h +28 -0
  51. package/cpp/unicode.cpp +11 -0
  52. package/ios/RNLlama.mm +43 -20
  53. package/ios/RNLlamaContext.h +9 -3
  54. package/ios/RNLlamaContext.mm +146 -33
  55. package/jest/mock.js +0 -1
  56. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  57. package/lib/commonjs/grammar.js +4 -2
  58. package/lib/commonjs/grammar.js.map +1 -1
  59. package/lib/commonjs/index.js +52 -15
  60. package/lib/commonjs/index.js.map +1 -1
  61. package/lib/module/NativeRNLlama.js.map +1 -1
  62. package/lib/module/grammar.js +2 -1
  63. package/lib/module/grammar.js.map +1 -1
  64. package/lib/module/index.js +51 -15
  65. package/lib/module/index.js.map +1 -1
  66. package/lib/typescript/NativeRNLlama.d.ts +122 -8
  67. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  68. package/lib/typescript/grammar.d.ts +5 -6
  69. package/lib/typescript/grammar.d.ts.map +1 -1
  70. package/lib/typescript/index.d.ts +15 -6
  71. package/lib/typescript/index.d.ts.map +1 -1
  72. package/package.json +2 -1
  73. package/src/NativeRNLlama.ts +135 -13
  74. package/src/grammar.ts +10 -8
  75. package/src/index.ts +104 -28
package/cpp/sgemm.cpp CHANGED
@@ -106,6 +106,10 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
 inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
+#if defined(__MMA__)
+typedef vector unsigned char vec_t;
+typedef __vector_quad acc_t;
+#endif
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED FUSED MULTIPLY ADD
 
@@ -1026,6 +1030,600 @@ class tinyBLAS_Q0_AVX {
 };
 #endif // __AVX__
 
+//PPC Implementation
+#if defined(__MMA__)
+
+#define SAVE_ACC(ACC, ii, jj) \
+   __builtin_mma_disassemble_acc(vec_C, ACC); \
+   for (int I = 0; I < 4; I++) { \
+      for (int J = 0; J < 4; J++) { \
+         *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
+      } \
+   } \
+
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_PPC {
+  public:
+    tinyBLAS_PPC(int64_t k,
+                 const TA *A, int64_t lda,
+                 const TB *B, int64_t ldb,
+                 TC *C, int64_t ldc,
+                 int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+
+    void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
+
+    void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
+        int64_t i, j;
+        float *aoffset = NULL, *boffset = NULL;
+        float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+
+        aoffset = const_cast<float*>(a);
+        boffset = vec;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
+                i = (cols >> 3);
+                if (i > 0) {
+                    __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
+                    vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
+                    vector float t1, t2, t3, t4, t5, t6, t7, t8;
+                    do {
+                        C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
+                        C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
+                        C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
+                        C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
+                        C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
+                        C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
+                        C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
+                        C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
+                        __builtin_vsx_disassemble_pair(c1, &C1);
+                        __builtin_vsx_disassemble_pair(c2, &C2);
+                        __builtin_vsx_disassemble_pair(c3, &C3);
+                        __builtin_vsx_disassemble_pair(c4, &C4);
+                        __builtin_vsx_disassemble_pair(c5, &C5);
+                        __builtin_vsx_disassemble_pair(c6, &C6);
+                        __builtin_vsx_disassemble_pair(c7, &C7);
+                        __builtin_vsx_disassemble_pair(c8, &C8);
+
+                        t1 = vec_mergeh(c1[0], c2[0]);
+                        t2 = vec_mergeh(c3[0], c4[0]);
+                        t3 = vec_mergeh(c5[0], c6[0]);
+                        t4 = vec_mergeh(c7[0], c8[0]);
+                        t5 = vec_xxpermdi(t1, t2, 0);
+                        t6 = vec_xxpermdi(t3, t4, 0);
+                        t7 = vec_xxpermdi(t1, t2, 3);
+                        t8 = vec_xxpermdi(t3, t4, 3);
+                        vec_xst(t5, 0, boffset);
+                        vec_xst(t6, 0, boffset+4);
+                        vec_xst(t7, 0, boffset+8);
+                        vec_xst(t8, 0, boffset+12);
+
+                        t1 = vec_mergel(c1[0], c2[0]);
+                        t2 = vec_mergel(c3[0], c4[0]);
+                        t3 = vec_mergel(c5[0], c6[0]);
+                        t4 = vec_mergel(c7[0], c8[0]);
+                        t5 = vec_xxpermdi(t1, t2, 0);
+                        t6 = vec_xxpermdi(t3, t4, 0);
+                        t7 = vec_xxpermdi(t1, t2, 3);
+                        t8 = vec_xxpermdi(t3, t4, 3);
+                        vec_xst(t5, 0, boffset+16);
+                        vec_xst(t6, 0, boffset+20);
+                        vec_xst(t7, 0, boffset+24);
+                        vec_xst(t8, 0, boffset+28);
+
+                        t1 = vec_mergeh(c1[1], c2[1]);
+                        t2 = vec_mergeh(c3[1], c4[1]);
+                        t3 = vec_mergeh(c5[1], c6[1]);
+                        t4 = vec_mergeh(c7[1], c8[1]);
+                        t5 = vec_xxpermdi(t1, t2, 0);
+                        t6 = vec_xxpermdi(t3, t4, 0);
+                        t7 = vec_xxpermdi(t1, t2, 3);
+                        t8 = vec_xxpermdi(t3, t4, 3);
+                        vec_xst(t5, 0, boffset+32);
+                        vec_xst(t6, 0, boffset+36);
+                        vec_xst(t7, 0, boffset+40);
+                        vec_xst(t8, 0, boffset+44);
+
+                        t1 = vec_mergel(c1[1], c2[1]);
+                        t2 = vec_mergel(c3[1], c4[1]);
+                        t3 = vec_mergel(c5[1], c6[1]);
+                        t4 = vec_mergel(c7[1], c8[1]);
+                        t5 = vec_xxpermdi(t1, t2, 0);
+                        t6 = vec_xxpermdi(t3, t4, 0);
+                        t7 = vec_xxpermdi(t1, t2, 3);
+                        t8 = vec_xxpermdi(t3, t4, 3);
+                        vec_xst(t5, 0, boffset+48);
+                        vec_xst(t6, 0, boffset+52);
+                        vec_xst(t7, 0, boffset+56);
+                        vec_xst(t8, 0, boffset+60);
+
+                        aoffset1 += 8*lda;
+                        aoffset2 += 8*lda;
+                        aoffset3 += 8*lda;
+                        aoffset4 += 8*lda;
+                        boffset += 64;
+                        i--;
+                    } while(i > 0);
+                }
+                if (cols & 4) {
+                    vector float c1, c2, c3, c4, c5, c6, c7, c8;
+                    vector float t1, t2, t3, t4, t5, t6, t7, t8;
+                    c1 = vec_xl(0, aoffset1);
+                    c2 = vec_xl(0, aoffset2);
+                    c3 = vec_xl(0, aoffset3);
+                    c4 = vec_xl(0, aoffset4);
+                    c5 = vec_xl(0, aoffset5);
+                    c6 = vec_xl(0, aoffset6);
+                    c7 = vec_xl(0, aoffset7);
+                    c8 = vec_xl(0, aoffset8);
+
+                    t1 = vec_mergeh(c1, c2);
+                    t2 = vec_mergeh(c3, c4);
+                    t3 = vec_mergeh(c5, c6);
+                    t4 = vec_mergeh(c7, c8);
+                    t5 = vec_xxpermdi(t1, t2, 0);
+                    t6 = vec_xxpermdi(t3, t4, 0);
+                    t7 = vec_xxpermdi(t1, t2, 3);
+                    t8 = vec_xxpermdi(t3, t4, 3);
+                    vec_xst(t5, 0, boffset);
+                    vec_xst(t6, 0, boffset+4);
+                    vec_xst(t7, 0, boffset+8);
+                    vec_xst(t8, 0, boffset+12);
+
+                    t1 = vec_mergel(c1, c2);
+                    t2 = vec_mergel(c3, c4);
+                    t3 = vec_mergel(c5, c6);
+                    t4 = vec_mergel(c7, c8);
+                    t5 = vec_xxpermdi(t1, t2, 0);
+                    t6 = vec_xxpermdi(t3, t4, 0);
+                    t7 = vec_xxpermdi(t1, t2, 3);
+                    t8 = vec_xxpermdi(t3, t4, 3);
+                    vec_xst(t5, 0, boffset+16);
+                    vec_xst(t6, 0, boffset+20);
+                    vec_xst(t7, 0, boffset+24);
+                    vec_xst(t8, 0, boffset+28);
+                }
+                j--;
+            } while(j > 0);
+        }
+
+        if (rows & 4) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset += 4 * lda;
+            i = (cols >> 3);
+            if (i > 0) {
+                __vector_pair C1, C2, C3, C4;
+                vector float c1[2], c2[2], c3[2], c4[2];
+                vector float t1, t2, t3, t4, t5, t6, t7, t8;
+                do {
+                    C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
+                    C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
+                    C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
+                    C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
+                    __builtin_vsx_disassemble_pair(c1, &C1);
+                    __builtin_vsx_disassemble_pair(c2, &C2);
+                    __builtin_vsx_disassemble_pair(c3, &C3);
+                    __builtin_vsx_disassemble_pair(c4, &C4);
+
+                    t1 = vec_mergeh(c1[0], c2[0]);
+                    t2 = vec_mergeh(c3[0], c4[0]);
+                    t3 = vec_mergel(c1[0], c2[0]);
+                    t4 = vec_mergel(c3[0], c4[0]);
+                    t5 = vec_xxpermdi(t1, t2, 0);
+                    t6 = vec_xxpermdi(t1, t2, 3);
+                    t7 = vec_xxpermdi(t3, t4, 0);
+                    t8 = vec_xxpermdi(t3, t4, 3);
+                    vec_xst(t5, 0, boffset);
+                    vec_xst(t6, 0, boffset+4);
+                    vec_xst(t7, 0, boffset+8);
+                    vec_xst(t8, 0, boffset+12);
+
+                    t1 = vec_mergeh(c1[1], c2[1]);
+                    t2 = vec_mergeh(c3[1], c4[1]);
+                    t3 = vec_mergel(c1[1], c2[1]);
+                    t4 = vec_mergel(c3[1], c4[1]);
+                    t5 = vec_xxpermdi(t1, t2, 0);
+                    t6 = vec_xxpermdi(t1, t2, 3);
+                    t7 = vec_xxpermdi(t3, t4, 0);
+                    t8 = vec_xxpermdi(t3, t4, 3);
+                    vec_xst(t5, 0, boffset+16);
+                    vec_xst(t6, 0, boffset+20);
+                    vec_xst(t7, 0, boffset+24);
+                    vec_xst(t8, 0, boffset+28);
+
+                    aoffset1 += 8*lda;
+                    aoffset2 += 8*lda;
+                    aoffset3 += 8*lda;
+                    aoffset4 += 8*lda;
+                    boffset += 32;
+                    i--;
+                } while(i > 0);
+            }
+
+            if (cols & 4) {
+                vector float c1, c2, c3, c4;
+                vector float t1, t2, t3, t4;
+                c1 = vec_xl(0, aoffset1);
+                c2 = vec_xl(0, aoffset2);
+                c3 = vec_xl(0, aoffset3);
+                c4 = vec_xl(0, aoffset4);
+
+                t1 = vec_mergeh(c1, c2);
+                t2 = vec_mergeh(c3, c4);
+                t3 = vec_xxpermdi(t1, t2, 0);
+                t4 = vec_xxpermdi(t1, t2, 3);
+                vec_xst(t3, 0, boffset);
+                vec_xst(t4, 0, boffset+4);
+
+                t1 = vec_mergel(c1, c2);
+                t2 = vec_mergel(c3, c4);
+                t3 = vec_xxpermdi(t1, t2, 0);
+                t4 = vec_xxpermdi(t1, t2, 3);
+                vec_xst(t3, 0, boffset+8);
+                vec_xst(t4, 0, boffset+12);
+            }
+        }
+        if (rows & 3) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            if (cols & 4) {
+                vector float c1, c2, c3, c4 = {0};
+                vector float t1, t2, t3, t4;
+                c1 = vec_xl(0, aoffset1);
+                c2 = vec_xl(0, aoffset2);
+                c3 = vec_xl(0, aoffset3);
+
+                t1 = vec_mergeh(c1, c2);
+                t2 = vec_mergeh(c3, c4);
+                t3 = vec_xxpermdi(t1, t2, 0);
+                t4 = vec_xxpermdi(t1, t2, 3);
+                vec_xst(t3, 0, boffset);
+                vec_xst(t4, 0, boffset+4);
+
+                t1 = vec_mergel(c1, c2);
+                t2 = vec_mergel(c3, c4);
+                t3 = vec_xxpermdi(t1, t2, 0);
+                t4 = vec_xxpermdi(t1, t2, 3);
+                vec_xst(t3, 0, boffset+8);
+                vec_xst(t4, 0, boffset+12);
+            }
+        }
+    }
+
+    void KERNEL_4x4(int64_t ii, int64_t jj) {
+        vec_t vec_A[4], vec_B[4], vec_C[4];
+        acc_t acc_0;
+        __builtin_mma_xxsetaccz(&acc_0);
+        for (int l = 0; l < k; l+=4) {
+            READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+            READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+    }
+
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[4], vec_B[8], vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int64_t l = 0; l < k; l+=4) {
+            READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+            READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
+            __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
+            __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
+            __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
+            __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+    }
+
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[4], vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int64_t l = 0; l < k; l+=4) {
+            READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
+            READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
+            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
+            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
+            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
+            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
+            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
+            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
+            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii+4, jj);
+    }
+
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[16], vec_B[16], vec_C[4];
+        acc_t acc_0, acc_1, acc_2, acc_3;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        __builtin_mma_xxsetaccz(&acc_2);
+        __builtin_mma_xxsetaccz(&acc_3);
+        for (int l = 0; l < k; l+=8) {
+            READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
+            READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+            for(int x = 0; x < 16; x+=2) {
+                __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
+                __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
+                __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
+                __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+        SAVE_ACC(&acc_2, ii+4, jj);
+        SAVE_ACC(&acc_3, ii+4, jj+4);
+    }
+
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        int m_rem = MIN(m - m0, 16);
+        int n_rem = MIN(n - n0, 16);
+        if (m_rem >= 16 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if(m_rem >= 8 && n_rem >= 16) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 8 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 8) {
+            mc = 4;
+            nc = 8;
+            gemm<4,8>(m0, m, n0, n);
+        } else if (m_rem >= 8 && n_rem >= 4) {
+            mc = 8;
+            nc = 4;
+            gemm<8,4>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 4) {
+            mc = 4;
+            nc = 4;
+            gemm<4,4>(m0, m, n0, n);
+        } else if ((m_rem < 4) && (n_rem > 4)) {
+            nc = 4;
+            switch(m_rem) {
+                case 1:
+                    mc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 2:
+                    mc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 3:
+                    mc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                default:
+                    return;
+            }
+        } else if ((m_rem > 4) && (n_rem < 4)) {
+            mc = 4;
+            switch(n_rem) {
+                case 1:
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 2:
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 3:
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                default:
+                    return;
+            }
+        } else {
+            switch((m_rem << 4) | n_rem) {
+                case 0x43:
+                    mc = 4;
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x42:
+                    mc = 4;
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x41:
+                    mc = 4;
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x34:
+                    mc = 3;
+                    nc = 4;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x33:
+                    mc = 3;
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x32:
+                    mc = 3;
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x31:
+                    mc = 3;
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x24:
+                    mc = 2;
+                    nc = 4;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x23:
+                    mc = 2;
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x22:
+                    mc = 2;
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x21:
+                    mc = 2;
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x14:
+                    mc = 1;
+                    nc = 4;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x13:
+                    mc = 1;
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x12:
+                    mc = 1;
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x11:
+                    mc = 1;
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                default:
+                    return;
+            }
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0;
+            __builtin_mma_xxsetaccz(&acc_0);
+            vec_t vec_A[4], vec_B[4];
+            for (int l=0; l<k; l+=4) {
+                if (RN >= 4 && RM == 1) {
+                    float* a = const_cast<float*>(A+(ii)*lda+l);
+                    READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+                    vec_A[0] = (vec_t)vec_xl(0,a);
+                    vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
+                    vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
+                    vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
+                } else {
+                    READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
+                    READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+                }
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < RN; J++) {
+                    *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (RM == 4 && RN == 4) {
+            kernel = &tinyBLAS_PPC::KERNEL_4x4;
+        } else if (RM == 4 && RN == 8) {
+            kernel = &tinyBLAS_PPC::KERNEL_4x8;
+        } else if (RM == 8 && RN == 4) {
+            kernel = &tinyBLAS_PPC::KERNEL_8x4;
+        } else if (RM == 8 && RN == 8) {
+            kernel = &tinyBLAS_PPC::KERNEL_8x8;
+        }
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            (this->*kernel)(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    TA *At;
+    TB *Bt;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+#endif
 } // namespace
 
 /**
@@ -1114,6 +1712,16 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             ith, nth};
         tb.matmul(m, n);
         return true;
+#elif defined(__MMA__)
+        if (k % 8)
+            return false;
+        tinyBLAS_PPC<float, float, float> tb{
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
 #else
         return false;
 #endif
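
At a glance, the net effect of the sgemm.cpp change: on POWER10 builds where __MMA__ is defined, llamafile_sgemm now routes FP32 matrix multiplication through the new tinyBLAS_PPC class added above, and only when K is a multiple of 8; otherwise the function returns false and the caller falls back to the generic ggml matmul. Below is a minimal sketch of that branch in isolation; sgemm_fp32_ppc is a hypothetical wrapper name used only for illustration, and the surrounding type and backend dispatch of llamafile_sgemm is omitted.

#if defined(__MMA__)
// Hypothetical standalone wrapper mirroring the new __MMA__ branch shown in the diff above.
static bool sgemm_fp32_ppc(int64_t m, int64_t n, int64_t k,
                           const float *A, int64_t lda,
                           const float *B, int64_t ldb,
                           float *C, int64_t ldc,
                           int ith, int nth) {
    if (k % 8)          // the PPC MMA path only handles K divisible by 8
        return false;   // caller falls back to the regular ggml matmul
    tinyBLAS_PPC<float, float, float> tb{k, A, lda, B, ldb, C, ldc, ith, nth};
    tb.matmul(m, n);    // thread ith of nth computes its share of output tiles
    return true;
}
#endif // __MMA__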