cui-llama.rn 1.4.6 → 1.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +9 -2
- package/android/src/main/jni.cpp +52 -34
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/cpp/binary-ops.cpp +158 -0
- package/cpp/binary-ops.h +16 -0
- package/cpp/chat.cpp +1769 -1779
- package/cpp/chat.h +9 -1
- package/cpp/common.cpp +20 -522
- package/cpp/common.h +13 -36
- package/cpp/cpu-common.h +72 -0
- package/cpp/ggml-common.h +12 -6
- package/cpp/ggml-cpu-aarch64.cpp +1557 -80
- package/cpp/ggml-cpu-impl.h +2 -21
- package/cpp/ggml-cpu-quants.c +904 -405
- package/cpp/ggml-cpu.c +909 -13237
- package/cpp/ggml-impl.h +50 -23
- package/cpp/ggml-metal-impl.h +77 -3
- package/cpp/ggml-metal.m +794 -580
- package/cpp/ggml.c +92 -3
- package/cpp/ggml.h +29 -5
- package/cpp/gguf.cpp +1 -0
- package/cpp/llama-adapter.cpp +55 -20
- package/cpp/llama-adapter.h +11 -9
- package/cpp/llama-arch.cpp +217 -16
- package/cpp/llama-arch.h +25 -0
- package/cpp/llama-batch.h +2 -2
- package/cpp/llama-chat.cpp +54 -2
- package/cpp/llama-chat.h +3 -0
- package/cpp/llama-context.cpp +2294 -1238
- package/cpp/llama-context.h +214 -77
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +1695 -0
- package/cpp/llama-graph.h +592 -0
- package/cpp/llama-hparams.cpp +8 -0
- package/cpp/llama-hparams.h +17 -0
- package/cpp/llama-io.cpp +15 -0
- package/cpp/llama-io.h +35 -0
- package/cpp/llama-kv-cache.cpp +965 -303
- package/cpp/llama-kv-cache.h +145 -151
- package/cpp/llama-memory.cpp +1 -0
- package/cpp/llama-memory.h +21 -0
- package/cpp/llama-mmap.cpp +1 -1
- package/cpp/llama-model-loader.cpp +10 -5
- package/cpp/llama-model-loader.h +5 -3
- package/cpp/llama-model.cpp +9194 -201
- package/cpp/llama-model.h +40 -1
- package/cpp/llama-sampling.cpp +5 -0
- package/cpp/llama-vocab.cpp +36 -5
- package/cpp/llama.cpp +51 -9984
- package/cpp/llama.h +102 -22
- package/cpp/log.cpp +34 -0
- package/cpp/minja/chat-template.hpp +15 -7
- package/cpp/minja/minja.hpp +120 -94
- package/cpp/ops.cpp +8723 -0
- package/cpp/ops.h +128 -0
- package/cpp/rn-llama.cpp +44 -53
- package/cpp/rn-llama.h +2 -12
- package/cpp/sampling.cpp +3 -0
- package/cpp/sgemm.cpp +533 -88
- package/cpp/simd-mappings.h +888 -0
- package/cpp/speculative.cpp +4 -4
- package/cpp/unary-ops.cpp +186 -0
- package/cpp/unary-ops.h +28 -0
- package/cpp/vec.cpp +258 -0
- package/cpp/vec.h +802 -0
- package/ios/CMakeLists.txt +5 -2
- package/ios/RNLlama.mm +2 -2
- package/ios/RNLlamaContext.mm +40 -24
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +6 -4
- package/src/index.ts +3 -1
- package/cpp/chat-template.hpp +0 -529
- package/cpp/minja.hpp +0 -2915
package/cpp/sgemm.cpp
CHANGED
@@ -55,6 +55,7 @@

 #include <atomic>
 #include <array>
+#include <type_traits>

 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
@@ -1092,13 +1093,403 @@ class tinyBLAS_Q0_PPC {
         }
     }

-    template<typename VA, typename VB>
-    void
+    template<typename VA, typename VB, int size>
+    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array<int, size>& comparray) {
         int64_t i, j;
         TA *aoffset = NULL;
         VA *vecOffset = NULL;
         TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
         TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        VB t1, t2, t3, t4, t5, t6, t7, t8;
+        const vector signed char lowMask = vec_splats((signed char)0xF);
+        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+        const vector signed char v8 = vec_splats((signed char)0x8);
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+        vector signed int vsum = {0};
+        vector signed int vsum2 = {0};
+
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
+
+                i = (cols >> 2);
+                if (i > 0) {
+                    do {
+                        c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                        c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                        c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                        c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
+                        c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
+                        c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
+                        c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
+                        c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
+
+                        c1[0] = vec_and(c1[1], lowMask);
+                        c1[1] = vec_sr(c1[1], v4);
+                        c1[0] = vec_sub(c1[0], v8);
+                        c1[1] = vec_sub(c1[1], v8);
+                        vsum = vec_sum4s(c1[0], vsum);
+                        vsum2 = vec_sum4s(c1[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c2[0] = vec_and(c2[1], lowMask);
+                        c2[1] = vec_sr(c2[1], v4);
+                        c2[0] = vec_sub(c2[0], v8);
+                        c2[1] = vec_sub(c2[1], v8);
+                        vsum = vec_sum4s(c2[0], vsum);
+                        vsum2 = vec_sum4s(c2[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c3[0] = vec_and(c3[1], lowMask);
+                        c3[1] = vec_sr(c3[1], v4);
+                        c3[0] = vec_sub(c3[0], v8);
+                        c3[1] = vec_sub(c3[1], v8);
+                        vsum = vec_sum4s(c3[0], vsum);
+                        vsum2 = vec_sum4s(c3[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c4[0] = vec_and(c4[1], lowMask);
+                        c4[1] = vec_sr(c4[1], v4);
+                        c4[0] = vec_sub(c4[0], v8);
+                        c4[1] = vec_sub(c4[1], v8);
+                        vsum = vec_sum4s(c4[0], vsum);
+                        vsum2 = vec_sum4s(c4[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c5[0] = vec_and(c5[1], lowMask);
+                        c5[1] = vec_sr(c5[1], v4);
+                        c5[0] = vec_sub(c5[0], v8);
+                        c5[1] = vec_sub(c5[1], v8);
+                        vsum = vec_sum4s(c5[0], vsum);
+                        vsum2 = vec_sum4s(c5[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c6[0] = vec_and(c6[1], lowMask);
+                        c6[1] = vec_sr(c6[1], v4);
+                        c6[0] = vec_sub(c6[0], v8);
+                        c6[1] = vec_sub(c6[1], v8);
+                        vsum = vec_sum4s(c6[0], vsum);
+                        vsum2 = vec_sum4s(c6[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c7[0] = vec_and(c7[1], lowMask);
+                        c7[1] = vec_sr(c7[1], v4);
+                        c7[0] = vec_sub(c7[0], v8);
+                        c7[1] = vec_sub(c7[1], v8);
+                        vsum = vec_sum4s(c7[0], vsum);
+                        vsum2 = vec_sum4s(c7[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c8[0] = vec_and(c8[1], lowMask);
+                        c8[1] = vec_sr(c8[1], v4);
+                        c8[0] = vec_sub(c8[0], v8);
+                        c8[1] = vec_sub(c8[1], v8);
+                        vsum = vec_sum4s(c8[0], vsum);
+                        vsum2 = vec_sum4s(c8[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        t1 = vec_perm(c1[0], c2[0], swiz1);
+                        t2 = vec_perm(c1[0], c2[0], swiz2);
+                        t3 = vec_perm(c3[0], c4[0], swiz1);
+                        t4 = vec_perm(c3[0], c4[0], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset);
+                        vec_xst(t6, 0, vecOffset+16);
+                        vec_xst(t7, 0, vecOffset+32);
+                        vec_xst(t8, 0, vecOffset+48);
+
+                        t1 = vec_perm(c1[1], c2[1], swiz1);
+                        t2 = vec_perm(c1[1], c2[1], swiz2);
+                        t3 = vec_perm(c3[1], c4[1], swiz1);
+                        t4 = vec_perm(c3[1], c4[1], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+64);
+                        vec_xst(t6, 0, vecOffset+80);
+                        vec_xst(t7, 0, vecOffset+96);
+                        vec_xst(t8, 0, vecOffset+112);
+
+                        t1 = vec_perm(c5[0], c6[0], swiz1);
+                        t2 = vec_perm(c5[0], c6[0], swiz2);
+                        t3 = vec_perm(c7[0], c8[0], swiz1);
+                        t4 = vec_perm(c7[0], c8[0], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+128);
+                        vec_xst(t6, 0, vecOffset+144);
+                        vec_xst(t7, 0, vecOffset+160);
+                        vec_xst(t8, 0, vecOffset+176);
+
+                        t1 = vec_perm(c5[1], c6[1], swiz1);
+                        t2 = vec_perm(c5[1], c6[1], swiz2);
+                        t3 = vec_perm(c7[1], c8[1], swiz1);
+                        t4 = vec_perm(c7[1], c8[1], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+192);
+                        vec_xst(t6, 0, vecOffset+208);
+                        vec_xst(t7, 0, vecOffset+224);
+                        vec_xst(t8, 0, vecOffset+240);
+
+                        aoffset1 += lda;
+                        aoffset2 += lda;
+                        aoffset3 += lda;
+                        aoffset4 += lda;
+                        aoffset5 += lda;
+                        aoffset6 += lda;
+                        aoffset7 += lda;
+                        aoffset8 += lda;
+                        vecOffset += 256;
+                        i--;
+                    } while (i > 0);
+                }
+                j--;
+            } while (j > 0);
+        }
+
+        if (rows & 4) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset += 4 * lda;
+
+            i = (cols >> 2);
+            if (i > 0) {
+                do {
+                    c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                    c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                    c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                    c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
+
+                    c1[0] = vec_and(c1[1], lowMask);
+                    c1[1] = vec_sr(c1[1], v4);
+                    c1[0] = vec_sub(c1[0], v8);
+                    c1[1] = vec_sub(c1[1], v8);
+                    vsum = vec_sum4s(c1[0], vsum);
+                    vsum2 = vec_sum4s(c1[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c2[0] = vec_and(c2[1], lowMask);
+                    c2[1] = vec_sr(c2[1], v4);
+                    c2[0] = vec_sub(c2[0], v8);
+                    c2[1] = vec_sub(c2[1], v8);
+                    vsum = vec_sum4s(c2[0], vsum);
+                    vsum2 = vec_sum4s(c2[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c3[0] = vec_and(c3[1], lowMask);
+                    c3[1] = vec_sr(c3[1], v4);
+                    c3[0] = vec_sub(c3[0], v8);
+                    c3[1] = vec_sub(c3[1], v8);
+                    vsum = vec_sum4s(c3[0], vsum);
+                    vsum2 = vec_sum4s(c3[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c4[0] = vec_and(c4[1], lowMask);
+                    c4[1] = vec_sr(c4[1], v4);
+                    c4[0] = vec_sub(c4[0], v8);
+                    c4[1] = vec_sub(c4[1], v8);
+                    vsum = vec_sum4s(c4[0], vsum);
+                    vsum2 = vec_sum4s(c4[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats( 0);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    aoffset4 += lda;
+                    vecOffset += 128;
+                    i--;
+                } while (i > 0);
+            }
+        }
+
+        if (rows & 3) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            i = (cols >> 2);
+            if (i > 0) {
+                do {
+                    switch(rows) {
+                        case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                        case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                        case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                            break;
+                    }
+                    c1[0] = vec_and(c1[1], lowMask);
+                    c1[1] = vec_sr(c1[1], v4);
+                    c1[0] = vec_sub(c1[0], v8);
+                    c1[1] = vec_sub(c1[1], v8);
+                    vsum = vec_sum4s(c1[0], vsum);
+                    vsum2 = vec_sum4s(c1[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c2[0] = vec_and(c2[1], lowMask);
+                    c2[1] = vec_sr(c2[1], v4);
+                    c2[0] = vec_sub(c2[0], v8);
+                    c2[1] = vec_sub(c2[1], v8);
+                    vsum = vec_sum4s(c2[0], vsum);
+                    vsum2 = vec_sum4s(c2[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c3[0] = vec_and(c3[1], lowMask);
+                    c3[1] = vec_sr(c3[1], v4);
+                    c3[0] = vec_sub(c3[0], v8);
+                    c3[1] = vec_sub(c3[1], v8);
+                    vsum = vec_sum4s(c3[0], vsum);
+                    vsum2 = vec_sum4s(c3[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c4[0] = vec_and(c4[1], lowMask);
+                    c4[1] = vec_sr(c4[1], v4);
+                    c4[0] = vec_sub(c4[0], v8);
+                    c4[1] = vec_sub(c4[1], v8);
+                    vsum = vec_sum4s(c4[0], vsum);
+                    vsum2 = vec_sum4s(c4[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    vecOffset += 128;
+                    i--;
+                } while(i > 0);
+            }
+        }
+    }
+
+    template<typename VA, typename VB>
+    void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+        int64_t i, j;
+        TB *aoffset = NULL;
+        VA *vecOffset = NULL;
+        TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
         __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
         VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
         VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
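Aside (editorial note, not part of the diff): the new packNormalInt4 above does two things per row of A — it unpacks the two 4-bit quants stored in each byte of qs, and it accumulates their per-row sum into comparray for the bias correction applied later in the kernels. A minimal scalar sketch of that per-row work, under the assumption of a simplified 32-quant block layout (the q4_row struct below is hypothetical, not the real block_q4_0 type):

#include <cstdint>

struct q4_row { uint8_t qs[16]; };   // hypothetical: 16 bytes = 32 packed 4-bit quants

// Scalar equivalent of the vec_and/vec_sr/vec_sub + vec_sum4s sequence above.
static int unpack_row_and_sum(const q4_row& r, int8_t out[32]) {
    int row_sum = 0;                                  // becomes one comparray[] entry
    for (int i = 0; i < 16; ++i) {
        int8_t lo = (int8_t)((r.qs[i] & 0x0F) - 8);   // vec_and(..., lowMask), then vec_sub(..., v8)
        int8_t hi = (int8_t)((r.qs[i] >> 4) - 8);     // vec_sr(..., v4),       then vec_sub(..., v8)
        out[i]      = lo;
        out[i + 16] = hi;
        row_sum += lo + hi;
    }
    return row_sum;
}

The swiz/vec_perm/vec_xst steps that follow in the real routine only rearrange these bytes into the tile layout expected by the MMA kernels.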
@@ -1111,24 +1502,24 @@ class tinyBLAS_Q0_PPC {
         vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
         vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};

-        aoffset = const_cast<
+        aoffset = const_cast<TB*>(a);
         vecOffset = vec;
         j = (rows >> 3);
         if (j > 0) {
             do {
-
-
-
-
-
-
-
-
-
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;

-
-
-
+                i = (cols >> 3);
+                if (i > 0) {
+                    do {
                         C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
                         C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
                         C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
@@ -1156,10 +1547,10 @@ class tinyBLAS_Q0_PPC {
                         t7 = vec_perm(t2, t4, swiz3);
                         t8 = vec_perm(t2, t4, swiz4);
                         if (flip == true) {
-
-
-
-
+                            t5 = vec_xor(t5, xor_vector);
+                            t6 = vec_xor(t6, xor_vector);
+                            t7 = vec_xor(t7, xor_vector);
+                            t8 = vec_xor(t8, xor_vector);
                         }
                         vec_xst(t5, 0, vecOffset);
                         vec_xst(t6, 0, vecOffset+16);
@@ -1175,10 +1566,10 @@ class tinyBLAS_Q0_PPC {
                         t7 = vec_perm(t2, t4, swiz3);
                         t8 = vec_perm(t2, t4, swiz4);
                         if (flip == true) {
-
-
-
-
+                            t5 = vec_xor(t5, xor_vector);
+                            t6 = vec_xor(t6, xor_vector);
+                            t7 = vec_xor(t7, xor_vector);
+                            t8 = vec_xor(t8, xor_vector);
                         }
                         vec_xst(t5, 0, vecOffset+64);
                         vec_xst(t6, 0, vecOffset+80);
@@ -1194,10 +1585,10 @@ class tinyBLAS_Q0_PPC {
                         t7 = vec_perm(t2, t4, swiz3);
                         t8 = vec_perm(t2, t4, swiz4);
                         if (flip == true) {
-
-
-
-
+                            t5 = vec_xor(t5, xor_vector);
+                            t6 = vec_xor(t6, xor_vector);
+                            t7 = vec_xor(t7, xor_vector);
+                            t8 = vec_xor(t8, xor_vector);
                         }
                         vec_xst(t5, 0, vecOffset+128);
                         vec_xst(t6, 0, vecOffset+144);
@@ -1213,10 +1604,10 @@ class tinyBLAS_Q0_PPC {
                         t7 = vec_perm(t2, t4, swiz3);
                         t8 = vec_perm(t2, t4, swiz4);
                         if (flip == true) {
-
-
-
-
+                            t5 = vec_xor(t5, xor_vector);
+                            t6 = vec_xor(t6, xor_vector);
+                            t7 = vec_xor(t7, xor_vector);
+                            t8 = vec_xor(t8, xor_vector);
                         }
                         vec_xst(t5, 0, vecOffset+192);
                         vec_xst(t6, 0, vecOffset+208);
@@ -1240,11 +1631,11 @@ class tinyBLAS_Q0_PPC {
         }

         if (rows & 4) {
-
-
-
-
-
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset += 4 * lda;

             i = (cols >> 3);
             if (i > 0) {
@@ -1311,7 +1702,7 @@ class tinyBLAS_Q0_PPC {
             aoffset2 = aoffset1 + lda;
             aoffset3 = aoffset2 + lda;
             i = (cols >> 3);
-
+            if (i > 0) {
                 do {
                     switch(rows) {
                         case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
@@ -1527,13 +1918,18 @@ class tinyBLAS_Q0_PPC {
     void KERNEL_4x8(int64_t ii, int64_t jj) {
         vec_t vec_A[8], vec_B[16] = {0};
         acc_t acc_0, acc_1;
-        std::array<int, 4> comparray;
+        std::array<int, 4> comparray {};
         vector float fin_res[8] = {0};
         vector float vs[8] = {0};
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
-
+            if (std::is_same_v<TA, block_q4_0>) {
+                packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
+            } else {
+                packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+            }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1545,15 +1941,17 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
                 }
             }
-
-
-
-
-
-
-
-
-
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 4; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
             compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
@@ -1565,13 +1963,18 @@ class tinyBLAS_Q0_PPC {
     void KERNEL_8x4(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[8] = {0};
         acc_t acc_0, acc_1;
-        std::array<int, 8> comparray;
+        std::array<int, 8> comparray {};
         vector float fin_res[8] = {0};
         vector float vs[8] = {0};
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
-
+            if (std::is_same_v<TA, block_q4_0>) {
+                packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+            } else {
+                packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1582,15 +1985,17 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                 }
             }
-
-
-
-
-
-
-
-
-
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 8; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
             compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
@@ -1602,15 +2007,20 @@ class tinyBLAS_Q0_PPC {
     void KERNEL_8x8(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[16] = {0};
         acc_t acc_0, acc_1, acc_2, acc_3;
-        std::array<int, 8> comparray;
+        std::array<int, 8> comparray {};
         vector float fin_res[16] = {0};
        vector float vs[16] = {0};
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
             __builtin_mma_xxsetaccz(&acc_2);
             __builtin_mma_xxsetaccz(&acc_3);
-
+            if (std::is_same_v<TA, block_q4_0>) {
+                packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+            } else {
+                packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1624,15 +2034,17 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
                 }
             }
-
-
-
-
-
-
-
-
-
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 8; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
             compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
@@ -1653,16 +2065,17 @@ class tinyBLAS_Q0_PPC {
         int64_t duty = (tiles + nth - 1) / nth;
         int64_t start = duty * ith;
         int64_t end = start + duty;
-        vec_t vec_A[8], vec_B[8] = {0};
+        vec_t vec_A[8] = {0}, vec_B[8] = {0};
         vector signed int vec_C[4];
         acc_t acc_0;
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;

         if (end > tiles)
             end = tiles;
         for (int64_t job = start; job < end; ++job) {
             int64_t ii = m0 + job / xtiles * RM;
             int64_t jj = n0 + job % xtiles * RN;
-            std::array<int,
+            std::array<int, 4> comparray{};
             vector float res[4] = {0};
             vector float fin_res[4] = {0};
             vector float vs[4] = {0};
@@ -1673,7 +2086,11 @@ class tinyBLAS_Q0_PPC {
                 __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
                 __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
                 __builtin_mma_xxsetaccz(&acc_0);
-
+                if (isAblock_q4) {
+                    packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
+                } else {
+                    packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+                }
                 packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
                 for(int x = 0; x < 8; x+=4) {
                     __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1687,17 +2104,18 @@ class tinyBLAS_Q0_PPC {
                 }
             }
             __builtin_mma_disassemble_acc(vec_C, &acc_0);
-
-
-
-
-
-
-
-
-
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < RM; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
-
             for (int i = 0; i < RM; i++) {
                 CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
                 res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
@@ -2013,6 +2431,7 @@ class tinyBLAS_PPC {
             }
         }
     }
+
     void KERNEL_4x4(int64_t ii, int64_t jj) {
         vec_t vec_A[4], vec_B[4], vec_C[4];
         acc_t acc_0;
@@ -2259,15 +2678,27 @@ class tinyBLAS_PPC {
         vec_t vec_C[4];
         acc_t acc_0;
         __builtin_mma_xxsetaccz(&acc_0);
-        vec_t vec_A[4], vec_B[4];
+        vec_t vec_A[4] {0}, vec_B[4] = {0};
         for (int l=0; l<k; l+=4) {
-
+            /* 'GEMV Forwarding' concept is used in first two conditional loops.
+             * when one of the matrix has a single row/column, the elements are
+             * broadcasted, instead of using packing routine to prepack the
+             * matrix elements.
+             */
+            if (RM == 1) {
                 TA* a = const_cast<TA*>(A+(ii)*lda+l);
-                packTranspose<vector float>(B+(jj*ldb)+l, ldb,
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
                 vec_A[0] = (vec_t)vec_xl(0,a);
                 vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
                 vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
                 vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
+            } else if (RN == 1) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+                TB* b = const_cast<TB*>(B+(jj)*ldb+l);
+                vec_B[0] = (vec_t)vec_xl(0,b);
+                vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
+                vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
+                vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
             } else {
                 packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
                 packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
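Aside (editorial note, not part of the diff): the 'GEMV Forwarding' branches added above skip packTranspose when the tile degenerates to a single row (RM == 1) or a single column (RN == 1); the lone row or column is loaded once and its elements are splatted into the vectors the packed path would otherwise produce. A rough scalar sketch of the RM == 1 case, using plain float arrays instead of the vec_t/MMA types:

#include <array>

using lane4 = std::array<float, 4>;                 // stand-in for one vec_t register

// Mirror of the RM == 1 branch: vec_A[0] gets the row itself, and vec_A[1..3] get
// broadcasts of its 2nd..4th elements, matching vec_xl + vec_splats in the diff.
static void forward_single_row(const float a_row[4], lane4 vec_A[4]) {
    vec_A[0] = { a_row[0], a_row[1], a_row[2], a_row[3] };
    for (int i = 1; i < 4; ++i) {
        vec_A[i] = { a_row[i], a_row[i], a_row[i], a_row[i] };
    }
}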
@@ -2371,8 +2802,10 @@ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t m, in
     assert(params->ith < params->nth);

     // only enable sgemm for prompt processing
+#if !defined(__MMA__)
     if (n < 2)
         return false;
+#endif

     if (Ctype != LM_GGML_TYPE_F32)
         return false;
@@ -2503,8 +2936,8 @@ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t m, in
             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
-
 #elif defined(__MMA__)
+        //TO-DO: Remove this condition once gemv forwarding is enabled.
         if (n < 8 && n != 4)
             return false;
         if (m < 8 && m != 4)
@@ -2516,7 +2949,6 @@ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t m, in
             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
-
 #else
         return false;
 #endif
@@ -2541,6 +2973,19 @@ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t m, in
             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
+#elif defined(__MMA__)
+        //TO-DO: Remove this condition once gemv forwarding is enabled.
+        if (n < 8 && n != 4)
+            return false;
+        if (m < 8 && m != 4)
+            return false;
+        tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float> tb{
+            k, (const block_q4_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            params->ith, params->nth};
+        tb.matmul(m, n);
+        return true;
 #else
         return false;
 #endif
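Note on the bias correction visible in the Q0 kernels above (editorial, not part of the diff): the B operand is flipped with vec_xor before the integer multiply, and assuming xor_vector holds 0x80 in every byte (as in the upstream llamafile kernels) this turns each signed Q8_0 value b into b + 128 so the accumulate can use an unsigned operand. The integer accumulator therefore computes

    sum_k a_k * (b_k + 128) = sum_k a_k * b_k + 128 * sum_k a_k

which is why each kernel keeps the per-row sums of A in comparray[] and adds CA[i] = comparray[i] * -128.0 before scaling by the two block deltas. packNormalInt4 produces those row sums as a by-product of unpacking, so the separate summation loop over qs is skipped when the A block is q4_0 (the new isAblock_q4 checks).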