cui-llama.rn 1.3.5 → 1.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +22 -1
- package/android/src/main/CMakeLists.txt +25 -20
- package/android/src/main/java/com/rnllama/LlamaContext.java +31 -9
- package/android/src/main/java/com/rnllama/RNLlama.java +98 -0
- package/android/src/main/jni-utils.h +94 -0
- package/android/src/main/jni.cpp +108 -37
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +15 -0
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +15 -0
- package/cpp/common.cpp +1982 -1965
- package/cpp/common.h +665 -657
- package/cpp/ggml-backend-reg.cpp +5 -0
- package/cpp/ggml-backend.cpp +5 -2
- package/cpp/ggml-cpp.h +1 -0
- package/cpp/ggml-cpu-aarch64.cpp +6 -1
- package/cpp/ggml-cpu-quants.c +5 -1
- package/cpp/ggml-cpu.c +14122 -14122
- package/cpp/ggml-cpu.cpp +627 -627
- package/cpp/ggml-impl.h +11 -16
- package/cpp/ggml-metal-impl.h +288 -0
- package/cpp/ggml-metal.m +2 -2
- package/cpp/ggml-opt.cpp +854 -0
- package/cpp/ggml-opt.h +216 -0
- package/cpp/ggml.c +0 -1276
- package/cpp/ggml.h +0 -140
- package/cpp/gguf.cpp +1325 -0
- package/cpp/gguf.h +202 -0
- package/cpp/llama-adapter.cpp +346 -0
- package/cpp/llama-adapter.h +73 -0
- package/cpp/llama-arch.cpp +1434 -0
- package/cpp/llama-arch.h +395 -0
- package/cpp/llama-batch.cpp +368 -0
- package/cpp/llama-batch.h +88 -0
- package/cpp/llama-chat.cpp +567 -0
- package/cpp/llama-chat.h +51 -0
- package/cpp/llama-context.cpp +1771 -0
- package/cpp/llama-context.h +128 -0
- package/cpp/llama-cparams.cpp +1 -0
- package/cpp/llama-cparams.h +37 -0
- package/cpp/llama-cpp.h +30 -0
- package/cpp/llama-grammar.cpp +1 -0
- package/cpp/llama-grammar.h +3 -1
- package/cpp/llama-hparams.cpp +71 -0
- package/cpp/llama-hparams.h +140 -0
- package/cpp/llama-impl.cpp +167 -0
- package/cpp/llama-impl.h +16 -136
- package/cpp/llama-kv-cache.cpp +718 -0
- package/cpp/llama-kv-cache.h +218 -0
- package/cpp/llama-mmap.cpp +589 -0
- package/cpp/llama-mmap.h +67 -0
- package/cpp/llama-model-loader.cpp +1011 -0
- package/cpp/llama-model-loader.h +158 -0
- package/cpp/llama-model.cpp +2202 -0
- package/cpp/llama-model.h +391 -0
- package/cpp/llama-sampling.cpp +117 -4
- package/cpp/llama-vocab.cpp +21 -28
- package/cpp/llama-vocab.h +13 -1
- package/cpp/llama.cpp +12547 -23528
- package/cpp/llama.h +31 -6
- package/cpp/rn-llama.hpp +90 -87
- package/cpp/sgemm.cpp +776 -70
- package/cpp/sgemm.h +14 -14
- package/cpp/unicode.cpp +6 -0
- package/ios/RNLlama.mm +47 -0
- package/ios/RNLlamaContext.h +3 -1
- package/ios/RNLlamaContext.mm +71 -14
- package/jest/mock.js +15 -3
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +33 -37
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +31 -35
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +26 -6
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +21 -36
- package/lib/typescript/index.d.ts.map +1 -1
- package/llama-rn.podspec +4 -18
- package/package.json +2 -3
- package/src/NativeRNLlama.ts +32 -13
- package/src/index.ts +52 -47
package/cpp/sgemm.cpp
CHANGED
@@ -54,6 +54,7 @@
 #include "ggml-quants.h"

 #include <atomic>
+#include <array>

 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
@@ -1000,8 +1001,10 @@ class tinyBLAS_Q0_AVX {

     inline __m256 updot(__m256i u, __m256i s) {
         __m256i res;
-#if defined(
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
         res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
+#elif defined(__AVXVNNI__)
+        res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
 #else
         res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
 #endif
@@ -1049,6 +1052,704 @@ class tinyBLAS_Q0_AVX {
     } \
     } \

+template <typename TA, typename TB, typename TC>
+class tinyBLAS_Q0_PPC {
+  public:
+    tinyBLAS_Q0_PPC(int64_t k,
+                    const TA *A, int64_t lda,
+                    const TB *B, int64_t ldb,
+                    TC *C, int64_t ldc,
+                    int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+
+    template<int RM, int RN>
+    inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
+        for (int I = 0; I < RM; I++) {
+            for (int J = 0; J < RN; J++) {
+                *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
+            }
+        }
+    }
+
+    template<int size>
+    inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
+        vector signed int vec_C[4];
+        vector float CA[4] = {0};
+        vector float res[4] = {0};
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int i = 0; i < 4; i++) {
+            CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
+            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+            fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
+        }
+    }
+
+    template<typename VA, typename VB>
+    void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        VA *vecOffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
+        VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
+        VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
+        VB t1, t2, t3, t4, t5, t6, t7, t8;
+        vector unsigned char xor_vector;
+        uint8_t flip_vec = 0x80;
+        xor_vector = vec_splats(flip_vec);
+        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
+
+                i = (cols >> 3);
+                if (i > 0) {
+                    do {
+                        C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+                        C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+                        C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+                        C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
+                        C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
+                        C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
+                        C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
+                        C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
+
+                        __builtin_vsx_disassemble_pair(c1, &C1);
+                        __builtin_vsx_disassemble_pair(c2, &C2);
+                        __builtin_vsx_disassemble_pair(c3, &C3);
+                        __builtin_vsx_disassemble_pair(c4, &C4);
+                        __builtin_vsx_disassemble_pair(c5, &C5);
+                        __builtin_vsx_disassemble_pair(c6, &C6);
+                        __builtin_vsx_disassemble_pair(c7, &C7);
+                        __builtin_vsx_disassemble_pair(c8, &C8);
+
+                        t1 = vec_perm(c1[0], c2[0], swiz1);
+                        t2 = vec_perm(c1[0], c2[0], swiz2);
+                        t3 = vec_perm(c3[0], c4[0], swiz1);
+                        t4 = vec_perm(c3[0], c4[0], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        if (flip == true) {
+                            t5 = vec_xor(t5, xor_vector);
+                            t6 = vec_xor(t6, xor_vector);
+                            t7 = vec_xor(t7, xor_vector);
+                            t8 = vec_xor(t8, xor_vector);
+                        }
+                        vec_xst(t5, 0, vecOffset);
+                        vec_xst(t6, 0, vecOffset+16);
+                        vec_xst(t7, 0, vecOffset+32);
+                        vec_xst(t8, 0, vecOffset+48);
+
+                        t1 = vec_perm(c1[1], c2[1], swiz1);
+                        t2 = vec_perm(c1[1], c2[1], swiz2);
+                        t3 = vec_perm(c3[1], c4[1], swiz1);
+                        t4 = vec_perm(c3[1], c4[1], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        if (flip == true) {
+                            t5 = vec_xor(t5, xor_vector);
+                            t6 = vec_xor(t6, xor_vector);
+                            t7 = vec_xor(t7, xor_vector);
+                            t8 = vec_xor(t8, xor_vector);
+                        }
+                        vec_xst(t5, 0, vecOffset+64);
+                        vec_xst(t6, 0, vecOffset+80);
+                        vec_xst(t7, 0, vecOffset+96);
+                        vec_xst(t8, 0, vecOffset+112);
+
+                        t1 = vec_perm(c5[0], c6[0], swiz1);
+                        t2 = vec_perm(c5[0], c6[0], swiz2);
+                        t3 = vec_perm(c7[0], c8[0], swiz1);
+                        t4 = vec_perm(c7[0], c8[0], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        if (flip == true) {
+                            t5 = vec_xor(t5, xor_vector);
+                            t6 = vec_xor(t6, xor_vector);
+                            t7 = vec_xor(t7, xor_vector);
+                            t8 = vec_xor(t8, xor_vector);
+                        }
+                        vec_xst(t5, 0, vecOffset+128);
+                        vec_xst(t6, 0, vecOffset+144);
+                        vec_xst(t7, 0, vecOffset+160);
+                        vec_xst(t8, 0, vecOffset+176);
+
+                        t1 = vec_perm(c5[1], c6[1], swiz1);
+                        t2 = vec_perm(c5[1], c6[1], swiz2);
+                        t3 = vec_perm(c7[1], c8[1], swiz1);
+                        t4 = vec_perm(c7[1], c8[1], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        if (flip == true) {
+                            t5 = vec_xor(t5, xor_vector);
+                            t6 = vec_xor(t6, xor_vector);
+                            t7 = vec_xor(t7, xor_vector);
+                            t8 = vec_xor(t8, xor_vector);
+                        }
+                        vec_xst(t5, 0, vecOffset+192);
+                        vec_xst(t6, 0, vecOffset+208);
+                        vec_xst(t7, 0, vecOffset+224);
+                        vec_xst(t8, 0, vecOffset+240);
+
+                        aoffset1 += lda;
+                        aoffset2 += lda;
+                        aoffset3 += lda;
+                        aoffset4 += lda;
+                        aoffset5 += lda;
+                        aoffset6 += lda;
+                        aoffset7 += lda;
+                        aoffset8 += lda;
+                        vecOffset += 256;
+                        i--;
+                    } while(i > 0);
+                }
+                j--;
+            } while(j > 0);
+        }
+
+        if (rows & 4) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset += 4 * lda;
+
+            i = (cols >> 3);
+            if (i > 0) {
+                do {
+                    C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+                    C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+                    C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+                    C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
+
+                    __builtin_vsx_disassemble_pair(c1, &C1);
+                    __builtin_vsx_disassemble_pair(c2, &C2);
+                    __builtin_vsx_disassemble_pair(c3, &C3);
+                    __builtin_vsx_disassemble_pair(c4, &C4);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                        t5 = vec_xor(t5, xor_vector);
+                        t6 = vec_xor(t6, xor_vector);
+                        t7 = vec_xor(t7, xor_vector);
+                        t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                        t5 = vec_xor(t5, xor_vector);
+                        t6 = vec_xor(t6, xor_vector);
+                        t7 = vec_xor(t7, xor_vector);
+                        t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    aoffset4 += lda;
+                    vecOffset += 128;
+                    i--;
+                } while(i > 0);
+            }
+        }
+        if (rows & 3) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            i = (cols >> 3);
+            if (i > 0) {
+                do {
+                    switch(rows) {
+                        case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+                                __builtin_vsx_disassemble_pair(c3, &C3);
+                        case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+                                __builtin_vsx_disassemble_pair(c2, &C2);
+                        case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+                                __builtin_vsx_disassemble_pair(c1, &C1);
+                                break;
+                    }
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                        t5 = vec_xor(t5, xor_vector);
+                        t6 = vec_xor(t6, xor_vector);
+                        t7 = vec_xor(t7, xor_vector);
+                        t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                        t5 = vec_xor(t5, xor_vector);
+                        t6 = vec_xor(t6, xor_vector);
+                        t7 = vec_xor(t7, xor_vector);
+                        t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    vecOffset += 128;
+                    i--;
+                } while(i > 0);
+            }
+        }
+    }
+
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        int m_rem = MIN(m - m0, 8);
+        int n_rem = MIN(n - n0, 8);
+        // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance
+        // issues. After resolving them, below code will be enabled.
+        /*if (m_rem >= 16 && n_rem >= 8) {
+            mc = 16;
+            nc = 8;
+            gemm<16,8>(m0, m, n0, n);
+        } else if(m_rem >= 8 && n_rem >= 16) {
+            mc = 8;
+            nc = 16;
+            gemm<8,16>(m0, m, n0, n);
+        }*/
+        if (m_rem >= 8 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 8) {
+            mc = 4;
+            nc = 8;
+            gemm<4,8>(m0, m, n0, n);
+        } else if (m_rem >= 8 && n_rem >= 4) {
+            mc = 8;
+            nc = 4;
+            gemm<8,4>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 4) {
+            mc = 4;
+            nc = 4;
+            gemm_small<4, 4>(m0, m, n0, n);
+        } else if ((m_rem < 4) && (n_rem > 4)) {
+            nc = 4;
+            switch(m_rem) {
+                case 1:
+                    mc = 1;
+                    gemm_small<1, 4>(m0, m, n0, n);
+                    break;
+                case 2:
+                    mc = 2;
+                    gemm_small<2, 4>(m0, m, n0, n);
+                    break;
+                case 3:
+                    mc = 3;
+                    gemm_small<3, 4>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else if ((m_rem > 4) && (n_rem < 4)) {
+            mc = 4;
+            switch(n_rem) {
+                case 1:
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else {
+            switch((m_rem << 4) | n_rem) {
+                case 0x43:
+                    mc = 4;
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                case 0x42:
+                    mc = 4;
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 0x41:
+                    mc = 4;
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 0x34:
+                    mc = 3;
+                    nc = 4;
+                    gemm_small<3, 4>(m0, m, n0, n);
+                    break;
+                case 0x33:
+                    mc = 3;
+                    nc = 3;
+                    gemm_small<3, 3>(m0, m, n0, n);
+                    break;
+                case 0x32:
+                    mc = 3;
+                    nc = 2;
+                    gemm_small<3, 2>(m0, m, n0, n);
+                    break;
+                case 0x31:
+                    mc = 3;
+                    nc = 1;
+                    gemm_small<3, 1>(m0, m, n0, n);
+                    break;
+                case 0x24:
+                    mc = 2;
+                    nc = 4;
+                    gemm_small<2, 4>(m0, m, n0, n);
+                    break;
+                case 0x23:
+                    mc = 2;
+                    nc = 3;
+                    gemm_small<2, 3>(m0, m, n0, n);
+                    break;
+                case 0x22:
+                    mc = 2;
+                    nc = 2;
+                    gemm_small<2, 2>(m0, m, n0, n);
+                    break;
+                case 0x21:
+                    mc = 2;
+                    nc = 1;
+                    gemm_small<2, 1>(m0, m, n0, n);
+                    break;
+                case 0x14:
+                    mc = 1;
+                    nc = 4;
+                    gemm_small<1, 4>(m0, m, n0, n);
+                    break;
+                case 0x13:
+                    mc = 1;
+                    nc = 3;
+                    gemm_small<1, 3>(m0, m, n0, n);
+                    break;
+                case 0x12:
+                    mc = 1;
+                    nc = 2;
+                    gemm_small<1, 2>(m0, m, n0, n);
+                    break;
+                case 0x11:
+                    mc = 1;
+                    nc = 1;
+                    gemm_small<1, 1>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[16] = {0};
+        acc_t acc_0, acc_1;
+        std::array<int, 4> comparray;
+        vector float fin_res[8] = {0};
+        vector float vs[8] = {0};
+        for (int l = 0; l < k; l++) {
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+            packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+            for(int x = 0; x < 8; x++) {
+                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
+            }
+            for (int I = 0; I<4; I++) {
+                for (int J = 0; J<4; J++) {
+                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                    *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+                }
+            }
+            auto aoffset = A+(ii*lda)+l;
+            for (int i = 0; i < 4; i++) {
+                comparray[i] = 0;
+                int ca = 0;
+                const int8_t *at = aoffset->qs;
+                for (int j = 0; j < 32; j++)
+                    ca += (int)*at++;
+                comparray[i] = ca;
+                aoffset += lda;
+            }
+            compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
+        }
+        save_res<4, 4>(ii, jj, 0, fin_res);
+        save_res<4, 4>(ii, jj+4, 4, fin_res);
+    }
+
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
+        vec_t vec_A[16], vec_B[8] = {0};
+        acc_t acc_0, acc_1;
+        std::array<int, 8> comparray;
+        vector float fin_res[8] = {0};
+        vector float vs[8] = {0};
+        for (int l = 0; l < k; l++) {
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
+            for(int x = 0; x < 8; x++) {
+                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
+            }
+            for (int I = 0; I<8; I++) {
+                for (int J = 0; J<4; J++) {
+                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                }
+            }
+            auto aoffset = A+(ii*lda)+l;
+            for (int i = 0; i < 8; i++) {
+                comparray[i] = 0;
+                int ca = 0;
+                const int8_t *at = aoffset->qs;
+                for (int j = 0; j < 32; j++)
+                    ca += (int)*at++;
+                comparray[i] = ca;
+                aoffset += lda;
+            }
+            compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
+        }
+        save_res<4, 4>(ii, jj, 0, fin_res);
+        save_res<4, 4>(ii+4, jj, 4, fin_res);
+    }
+
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[16], vec_B[16] = {0};
+        acc_t acc_0, acc_1, acc_2, acc_3;
+        std::array<int, 8> comparray;
+        vector float fin_res[16] = {0};
+        vector float vs[16] = {0};
+        for (int l = 0; l < k; l++) {
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            __builtin_mma_xxsetaccz(&acc_2);
+            __builtin_mma_xxsetaccz(&acc_3);
+            packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+            for(int x = 0; x < 8; x++) {
+                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
+                __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
+            }
+            for (int I = 0; I<8; I++) {
+                for (int J = 0; J<4; J++) {
+                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                    *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+                }
+            }
+            auto aoffset = A+(ii*lda)+l;
+            for (int i = 0; i < 8; i++) {
+                comparray[i] = 0;
+                int ca = 0;
+                const int8_t *at = aoffset->qs;
+                for (int j = 0; j < 32; j++)
+                    ca += (int)*at++;
+                comparray[i] = ca;
+                aoffset += lda;
+            }
+            compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
+            compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
+            compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
+        }
+        save_res<4, 4>(ii, jj, 0, fin_res);
+        save_res<4, 4>(ii+4, jj, 4, fin_res);
+        save_res<4, 4>(ii, jj+4, 8, fin_res);
+        save_res<4, 4>(ii+4, jj+4, 12, fin_res);
+    }
+
+    template<int RM, int RN>
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        vec_t vec_A[8], vec_B[8] = {0};
+        vector signed int vec_C[4];
+        acc_t acc_0;
+
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            std::array<int, RM> comparray;
+            vector float res[4] = {0};
+            vector float fin_res[4] = {0};
+            vector float vs[4] = {0};
+            vector float CA[4] = {0};
+            __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
+            __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
+            for (int l = 0; l < k; l++) {
+                __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
+                __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
+                __builtin_mma_xxsetaccz(&acc_0);
+                packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+                packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
+                for(int x = 0; x < 8; x+=4) {
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
+                }
+                for (int I = 0; I<RM; I++) {
+                    for (int J = 0; J<RN; J++) {
+                        *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                    }
+                }
+                __builtin_mma_disassemble_acc(vec_C, &acc_0);
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < RM; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    const int8_t *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
+
+                for (int i = 0; i < RM; i++) {
+                    CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
+                    res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+                    fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
+                }
+            }
+            save_res<RM, RN>(ii, jj, 0, fin_res);
+        }
+    }
+
+    template<int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr(RM == 4 && RN == 8) {
+            KERNEL_4x8(ii,jj);
+        } else if constexpr(RM == 8 && RN == 4) {
+            KERNEL_8x4(ii,jj);
+        } else if constexpr(RM == 8 && RN == 8) {
+            KERNEL_8x8(ii,jj);
+        } else {
+            static_assert(false, "RN/RM values not supported");
+        }
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            kernel<RM, RN>(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    TA *At;
+    TB *Bt;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+
 template <typename TA, typename TB, typename TC>
 class tinyBLAS_PPC {
   public:
@@ -1068,13 +1769,17 @@ class tinyBLAS_PPC {

     void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);

-
+    template<typename VA>
+    void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
         int64_t i, j;
-
-
-
-
-
+        TA *aoffset = NULL, *boffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
+        VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        VA t1, t2, t3, t4, t5, t6, t7, t8;
+        aoffset = const_cast<TA*>(a);
         boffset = vec;
         j = (rows >> 3);
         if (j > 0) {
@@ -1090,9 +1795,6 @@ class tinyBLAS_PPC {
                aoffset += 8 * lda;
                i = (cols >> 3);
                if (i > 0) {
-                   __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
-                   vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
-                   vector float t1, t2, t3, t4, t5, t6, t7, t8;
                    do {
                        C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
                        C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
@@ -1172,21 +1874,19 @@ class tinyBLAS_PPC {
                    } while(i > 0);
                }
                if (cols & 4) {
-
-
-
-
-
-
-
-
-
-
-
-
-
-                   t3 = vec_mergeh(c5, c6);
-                   t4 = vec_mergeh(c7, c8);
+                   c1[0] = vec_xl(0, aoffset1);
+                   c2[0] = vec_xl(0, aoffset2);
+                   c3[0] = vec_xl(0, aoffset3);
+                   c4[0] = vec_xl(0, aoffset4);
+                   c5[0] = vec_xl(0, aoffset5);
+                   c6[0] = vec_xl(0, aoffset6);
+                   c7[0] = vec_xl(0, aoffset7);
+                   c8[0] = vec_xl(0, aoffset8);
+
+                   t1 = vec_mergeh(c1[0], c2[0]);
+                   t2 = vec_mergeh(c3[0], c4[0]);
+                   t3 = vec_mergeh(c5[0], c6[0]);
+                   t4 = vec_mergeh(c7[0], c8[0]);
                    t5 = vec_xxpermdi(t1, t2, 0);
                    t6 = vec_xxpermdi(t3, t4, 0);
                    t7 = vec_xxpermdi(t1, t2, 3);
@@ -1196,10 +1896,10 @@ class tinyBLAS_PPC {
                    vec_xst(t7, 0, boffset+8);
                    vec_xst(t8, 0, boffset+12);

-                   t1 = vec_mergel(c1, c2);
-                   t2 = vec_mergel(c3, c4);
-                   t3 = vec_mergel(c5, c6);
-                   t4 = vec_mergel(c7, c8);
+                   t1 = vec_mergel(c1[0], c2[0]);
+                   t2 = vec_mergel(c3[0], c4[0]);
+                   t3 = vec_mergel(c5[0], c6[0]);
+                   t4 = vec_mergel(c7[0], c8[0]);
                    t5 = vec_xxpermdi(t1, t2, 0);
                    t6 = vec_xxpermdi(t3, t4, 0);
                    t7 = vec_xxpermdi(t1, t2, 3);
@@ -1221,9 +1921,6 @@ class tinyBLAS_PPC {
            aoffset += 4 * lda;
            i = (cols >> 3);
            if (i > 0) {
-               __vector_pair C1, C2, C3, C4;
-               vector float c1[2], c2[2], c3[2], c4[2];
-               vector float t1, t2, t3, t4, t5, t6, t7, t8;
                do {
                    C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
                    C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
@@ -1270,22 +1967,20 @@ class tinyBLAS_PPC {
            }

            if (cols & 4) {
-
-
-
-
-
-
-
-               t1 = vec_mergeh(c1, c2);
-               t2 = vec_mergeh(c3, c4);
+               c1[0] = vec_xl(0, aoffset1);
+               c2[0] = vec_xl(0, aoffset2);
+               c3[0] = vec_xl(0, aoffset3);
+               c4[0] = vec_xl(0, aoffset4);
+
+               t1 = vec_mergeh(c1[0], c2[0]);
+               t2 = vec_mergeh(c3[0], c4[0]);
                t3 = vec_xxpermdi(t1, t2, 0);
                t4 = vec_xxpermdi(t1, t2, 3);
                vec_xst(t3, 0, boffset);
                vec_xst(t4, 0, boffset+4);

-               t1 = vec_mergel(c1, c2);
-               t2 = vec_mergel(c3, c4);
+               t1 = vec_mergel(c1[0], c2[0]);
+               t2 = vec_mergel(c3[0], c4[0]);
                t3 = vec_xxpermdi(t1, t2, 0);
                t4 = vec_xxpermdi(t1, t2, 3);
                vec_xst(t3, 0, boffset+8);
@@ -1297,21 +1992,19 @@ class tinyBLAS_PPC {
            aoffset2 = aoffset1 + lda;
            aoffset3 = aoffset2 + lda;
            if (cols & 4) {
-
-
-
-
-
-
-               t1 = vec_mergeh(c1, c2);
-               t2 = vec_mergeh(c3, c4);
+               c1[0] = vec_xl(0, aoffset1);
+               c2[0] = vec_xl(0, aoffset2);
+               c3[0] = vec_xl(0, aoffset3);
+
+               t1 = vec_mergeh(c1[0], c2[0]);
+               t2 = vec_mergeh(c3[0], c4[0]);
                t3 = vec_xxpermdi(t1, t2, 0);
                t4 = vec_xxpermdi(t1, t2, 3);
                vec_xst(t3, 0, boffset);
                vec_xst(t4, 0, boffset+4);

-               t1 = vec_mergel(c1, c2);
-               t2 = vec_mergel(c3, c4);
+               t1 = vec_mergel(c1[0], c2[0]);
+               t2 = vec_mergel(c3[0], c4[0]);
                t3 = vec_xxpermdi(t1, t2, 0);
                t4 = vec_xxpermdi(t1, t2, 3);
                vec_xst(t3, 0, boffset+8);
@@ -1319,14 +2012,13 @@ class tinyBLAS_PPC {
            }
        }
    }
-
    void KERNEL_4x4(int64_t ii, int64_t jj) {
        vec_t vec_A[4], vec_B[4], vec_C[4];
        acc_t acc_0;
        __builtin_mma_xxsetaccz(&acc_0);
        for (int l = 0; l < k; l+=4) {
-
-
+           packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+           packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
@@ -1341,8 +2033,8 @@ class tinyBLAS_PPC {
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int64_t l = 0; l < k; l+=4) {
-
-
+           packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+           packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -1362,8 +2054,8 @@ class tinyBLAS_PPC {
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int64_t l = 0; l < k; l+=4) {
-
-
+           packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
+           packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -1385,8 +2077,8 @@ class tinyBLAS_PPC {
        __builtin_mma_xxsetaccz(&acc_2);
        __builtin_mma_xxsetaccz(&acc_3);
        for (int l = 0; l < k; l+=8) {
-
-
+           packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
+           packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
            for(int x = 0; x < 16; x+=2) {
                __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
                __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
@@ -1569,15 +2261,15 @@ class tinyBLAS_PPC {
            vec_t vec_A[4], vec_B[4];
            for (int l=0; l<k; l+=4) {
                if (RN >= 4 && RM == 1) {
-
-
+                   TA* a = const_cast<TA*>(A+(ii)*lda+l);
+                   packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
                    vec_A[0] = (vec_t)vec_xl(0,a);
-                   vec_A[1] = (vec_t)vec_splats(*((
-                   vec_A[2] = (vec_t)vec_splats(*((
-                   vec_A[3] = (vec_t)vec_splats(*((
+                   vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
+                   vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
+                   vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
                } else {
-
-
+                   packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+                   packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
                }
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -1587,7 +2279,7 @@ class tinyBLAS_PPC {
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < RN; J++) {
-                   *((
+                   *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                }
            }
        }
@@ -1810,6 +2502,20 @@ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t m, in
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
+
+#elif defined(__MMA__)
+        if (n < 8 && n != 4)
+            return false;
+        if (m < 8 && m != 4)
+            return false;
+        tinyBLAS_Q0_PPC<block_q8_0, block_q8_0, float> tb{
+            k, (const block_q8_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            params->ith, params->nth};
+        tb.matmul(m, n);
+        return true;
+
 #else
        return false;
 #endif