cui-llama.rn 1.3.3 → 1.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +5 -7
- package/android/src/main/java/com/rnllama/LlamaContext.java +4 -4
- package/android/src/main/jni.cpp +9 -9
- package/cpp/common.cpp +21 -40
- package/cpp/common.h +21 -12
- package/cpp/ggml-backend-impl.h +38 -20
- package/cpp/ggml-backend-reg.cpp +216 -87
- package/cpp/ggml-backend.h +1 -0
- package/cpp/ggml-common.h +42 -48
- package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +591 -152
- package/cpp/ggml-cpu-aarch64.h +2 -26
- package/cpp/ggml-cpu-traits.cpp +36 -0
- package/cpp/ggml-cpu-traits.h +38 -0
- package/cpp/ggml-cpu.c +14122 -13971
- package/cpp/ggml-cpu.cpp +618 -715
- package/cpp/ggml-cpu.h +0 -17
- package/cpp/ggml-impl.h +6 -6
- package/cpp/ggml-metal.m +482 -24
- package/cpp/ggml-quants.c +0 -9
- package/cpp/ggml-threading.h +4 -2
- package/cpp/ggml.c +132 -43
- package/cpp/ggml.h +44 -13
- package/cpp/llama-sampling.cpp +35 -90
- package/cpp/llama-vocab.cpp +2 -1
- package/cpp/llama.cpp +737 -233
- package/cpp/llama.h +20 -16
- package/cpp/sampling.cpp +11 -16
- package/cpp/speculative.cpp +4 -0
- package/cpp/unicode.cpp +51 -51
- package/cpp/unicode.h +9 -10
- package/lib/commonjs/index.js +38 -1
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/index.js +36 -0
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +2 -3
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +36 -2
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +3 -3
- package/src/index.ts +46 -2
- package/cpp/amx/amx.cpp +0 -196
- package/cpp/amx/amx.h +0 -20
- package/cpp/amx/common.h +0 -101
- package/cpp/amx/mmq.cpp +0 -2524
- package/cpp/amx/mmq.h +0 -16
- package/cpp/ggml-aarch64.c +0 -129
- package/cpp/ggml-aarch64.h +0 -19
package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp}
@@ -1,20 +1,57 @@
-#define
+#define LM_GGML_COMMON_IMPL_CPP
+#define LM_GGML_COMMON_DECL_CPP
 #include "ggml-common.h"
+#include "ggml-backend-impl.h"
 
 #include "ggml-quants.h"
 #include "ggml-impl.h"
 #include "ggml-cpu.h"
 #include "ggml-cpu-impl.h"
+#include "ggml-cpu-traits.h"
 
-#include <
-#include <
-#include <
-#include <
-#include <
-#include <
+#include <cmath>
+#include <cstring>
+#include <cassert>
+#include <cfloat>
+#include <cstdlib> // for qsort
+#include <cstdio> // for LM_GGML_ASSERT
 
 #include "ggml-cpu-aarch64.h"
 
+// TODO: move to include file?
+template <int K> constexpr int QK_0() {
+    if constexpr (K == 4) {
+        return QK4_0;
+    }
+    if constexpr (K == 8) {
+        return QK8_0;
+    }
+    return -1;
+}
+
+template <int K, int N> struct block {
+    lm_ggml_half d[N]; // deltas for N qK_0 blocks
+    int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
+};
+
+// control size
+static_assert(sizeof(block<4, 4>) == 4 * sizeof(lm_ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
+static_assert(sizeof(block<4, 8>) == 8 * sizeof(lm_ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
+static_assert(sizeof(block<8, 4>) == 4 * sizeof(lm_ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
+static_assert(sizeof(block<8, 8>) == 8 * sizeof(lm_ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
+
+using block_q4_0x4 = block<4, 4>;
+using block_q4_0x8 = block<4, 8>;
+using block_q8_0x4 = block<8, 4>;
+using block_q8_0x8 = block<8, 8>;
+
+struct block_iq4_nlx4 {
+    lm_ggml_half d[4]; // deltas for 4 iq4_nl blocks
+    uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
+};
+
+static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(lm_ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
+
 #if defined(__GNUC__)
 #pragma GCC diagnostic ignored "-Woverlength-strings"
 #elif defined(_MSC_VER)
@@ -128,7 +165,7 @@ static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
 }
 
 static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
-#if defined(
+#if defined(__AVX512VNNI__)
     const __m512i zero = _mm512_setzero_si512();
     return _mm512_dpbusd_epi32(zero, ax, sy);
 #else
@@ -185,12 +222,12 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
 
 static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
-static void quantize_q8_0_4x4(const float *
+static void quantize_q8_0_4x4(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
 
-    block_q8_0x4 *
+    block_q8_0x4 * LM_GGML_RESTRICT y = (block_q8_0x4 *) vy;
 
 #if defined(__ARM_NEON)
     float32x4_t srcv[4][8];
@@ -279,12 +316,12 @@ static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int6
 #endif
 }
 
-static void quantize_q8_0_4x8(const float *
+static void quantize_q8_0_4x8(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
 
-    block_q8_0x4 *
+    block_q8_0x4 * LM_GGML_RESTRICT y = (block_q8_0x4 *) vy;
 
 #if defined(__ARM_NEON)
     float32x4_t srcv[4][8];
@@ -494,7 +531,7 @@ static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int6
 #endif
 }
 
-void quantize_mat_q8_0(const float *
+static void quantize_mat_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
     assert(nrow == 4);
     UNUSED(nrow);
     if (blck_size_interleave == 4) {
@@ -506,7 +543,7 @@ void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nro
     }
 }
 
-void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float *
+static void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
@@ -591,7 +628,7 @@ void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void
     }
 }
 
-void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float *
+static void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
@@ -701,7 +738,7 @@ void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void
     }
 }
 
-void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float *
+static void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 8;
@@ -974,7 +1011,7 @@ void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
     }
 }
 
-void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float *
+static void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
@@ -1070,7 +1107,7 @@ void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const vo
     }
 }
 
-void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float *
+static void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
@@ -1586,7 +1623,7 @@ void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void
     }
 }
 
-void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float *
+static void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
@@ -2040,7 +2077,7 @@ void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void
     }
 }
 
-void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float *
+static void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 8;
@@ -2560,31 +2597,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
     const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
 
     // Shuffle pattern one - right side input
-    const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
-    const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+    const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+    const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
 
-    const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
-    const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+    const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+    const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
 
-    const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
-    const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+    const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+    const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
 
-    const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
-    const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+    const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+    const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
 
     // Shuffle pattern two - right side input
 
-    const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
-    const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+    const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+    const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
 
-    const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
-    const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+    const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+    const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
 
-    const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
-    const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+    const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+    const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
 
-    const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
-    const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
+    const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+    const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
 
     // Scale values - Load the weight scale values of two block_q4_0x8
     const __m512 col_scale_f32 = LM_GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
@@ -2618,31 +2655,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
 
     // Shuffle pattern one - left side input
 
-    const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
-    const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+    const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+    const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
 
-    const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
-    const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+    const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+    const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
 
-    const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
-    const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+    const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+    const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
 
-    const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
-    const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+    const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+    const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
 
     // Shuffle pattern two - left side input
 
-    const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
-    const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+    const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+    const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
 
-    const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
-    const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+    const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+    const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
 
-    const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
-    const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+    const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+    const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
 
-    const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
-    const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+    const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+    const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
 
     // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
     // Resembles MMLAs into 2x2 matrices in ARM Version
@@ -2671,10 +2708,10 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
 
 
     // Straighten out to make 4 row vectors
-    __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
-    __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
-    __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
-    __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
+    __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+    __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+    __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+    __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
 
     // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
     const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
@@ -2753,31 +2790,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
     const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
 
     // Shuffle pattern one - right side input
-    const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
-    const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+    const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+    const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
 
-    const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
-    const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+    const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+    const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
 
-    const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
-    const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+    const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+    const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
 
-    const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
-    const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+    const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+    const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
 
     // Shuffle pattern two - right side input
 
-    const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
-    const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+    const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+    const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
 
-    const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
-    const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+    const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+    const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
 
-    const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
-    const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+    const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+    const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
 
-    const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
-    const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
+    const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+    const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
 
 
     // Scale values - Load the weight scale values of two block_q4_0x8
@@ -2809,31 +2846,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
 
     // Shuffle pattern one - left side input
 
-    const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
-    const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+    const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+    const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
 
-    const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
-    const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+    const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+    const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
 
-    const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
-    const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+    const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+    const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
 
-    const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
-    const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+    const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+    const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
 
     // Shuffle pattern two - left side input
 
-    const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
-    const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+    const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+    const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
 
-    const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
-    const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+    const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+    const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
 
-    const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
-    const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+    const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+    const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
 
-    const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
-    const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+    const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+    const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
 
     // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
     // Resembles MMLAs into 2x2 matrices in ARM Version
@@ -2862,10 +2899,10 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
 
 
     // Straighten out to make 4 row vectors
-    __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
-    __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
-    __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
-    __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
+    __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+    __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+    __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+    __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
 
     // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
     const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
@@ -3460,7 +3497,7 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
     }
 }
 
-void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float *
+static void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
@@ -3571,7 +3608,6 @@ void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const vo
     }
 }
 
-// FIXME: this code is duplicated from ggml-aarch64.c
 static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
     block_q4_0x4 out;
 
@@ -3641,20 +3677,20 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
     return out;
 }
 
-static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void *
+static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
     LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0);
     LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+    constexpr int nrows_interleaved = 4;
 
     block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
     const block_q4_0 * src = (const block_q4_0 *)data;
     block_q4_0 dst_tmp[4];
-    int nrow = t
-    int nrows_interleaved = 4;
+    int nrow = lm_ggml_nrows(t);
     int nblocks = t->ne[0] / QK4_0;
 
     LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
 
-    if (
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
         return -1;
     }
 
@@ -3672,20 +3708,20 @@ static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_bl
     LM_GGML_UNUSED(data_size);
 }
 
-static int repack_q4_0_to_q4_0_8_bl(struct lm_ggml_tensor *t, int interleave_block, const void *
+static int repack_q4_0_to_q4_0_8_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
     LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0);
     LM_GGML_ASSERT(interleave_block == 8);
+    constexpr int nrows_interleaved = 8;
 
     block_q4_0x8 * dst = (block_q4_0x8*)t->data;
     const block_q4_0 * src = (const block_q4_0*) data;
     block_q4_0 dst_tmp[8];
-    int nrow = t
-    int nrows_interleaved = 8;
+    int nrow = lm_ggml_nrows(t);
     int nblocks = t->ne[0] / QK4_0;
 
     LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
 
-    if (
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
         return -1;
     }
 
@@ -3712,16 +3748,18 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
 
     const int end = QK4_NL * 2 / blck_size_interleave;
 
-
-
-
-
-
-
-
-
-
-    }
+    // TODO: this branch seems wrong
+    //if (blck_size_interleave == 8) {
+    //    for (int i = 0; i < end; ++i) {
+    //        int src_id = i % 4;
+    //        int src_offset = (i / 4) * blck_size_interleave;
+    //        int dst_offset = i * blck_size_interleave;
+
+    //        // Using memcpy to avoid unaligned memory accesses
+    //        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+    //    }
+    //} else
+    if (blck_size_interleave == 4) {
         for (int i = 0; i < end; ++i) {
             int src_id = i % 4;
             int src_offset = (i / 4) * blck_size_interleave;
@@ -3736,20 +3774,21 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
     return out;
 }
 
-static int repack_iq4_nl_to_iq4_nl_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void *
+static int repack_iq4_nl_to_iq4_nl_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
     LM_GGML_ASSERT(t->type == LM_GGML_TYPE_IQ4_NL);
-    LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+    //LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+    LM_GGML_ASSERT(interleave_block == 4);
 
     block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
     const block_iq4_nl * src = (const block_iq4_nl *)data;
     block_iq4_nl dst_tmp[4];
-    int nrow = t
+    int nrow = lm_ggml_nrows(t);
     int nrows_interleaved = 4;
     int nblocks = t->ne[0] / QK4_0;
 
     LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
 
-    if (
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
     }
 
@@ -3767,57 +3806,457 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct lm_ggml_tensor * t, int interleav
     LM_GGML_UNUSED(data_size);
 }
 
-
-
-
-
-
+namespace ggml::cpu::aarch64 {
+// repack
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+int repack(struct lm_ggml_tensor *, const void *, size_t);
+
+// TODO: generalise.
+template <> int repack<block_q4_0, 4, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
+}
+
+template <> int repack<block_q4_0, 8, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
+}
+
+template <> int repack<block_q4_0, 8, 8>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
+}
+
+template <> int repack<block_iq4_nl, 4, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
+}
+
+// TODO: needs to be revisited
+//template <> int repack<block_iq4_nl, 8, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+//    return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
+//}
+
+// gemv
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+void gemv(int, float *, size_t, const void *, const void *, int, int);
+
+template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <>
+void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+// gemm
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+void gemm(int, float *, size_t, const void *, const void *, int, int);
+
+template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <>
+void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+class tensor_traits_base : public ggml::cpu::tensor_traits {
+  public:
+    virtual int repack(struct lm_ggml_tensor * t, const void * data, size_t data_size) = 0;
+};
+
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
+
+    bool work_size(int /* n_threads */, const struct lm_ggml_tensor * op, size_t & size) override {
+        // not realy a LM_GGML_TYPE_Q8_0 but same size.
+        switch (op->op) {
+            case LM_GGML_OP_MUL_MAT:
+                size = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, lm_ggml_nelements(op->src[1]));
+                return true;
+            case LM_GGML_OP_MUL_MAT_ID:
+                size = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, lm_ggml_nelements(op->src[1]));
+                size = LM_GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
+                size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
+                return true;
+            default:
+                // LM_GGML_ABORT("fatal error");
+                break;
+        }
+        return false;
     }
 
-
-switch (
-
-
-
-
-
-
-
-
-
-default:
-LM_GGML_ABORT("Unsupported type");
+    bool compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * op) override {
+        switch (op->op) {
+            case LM_GGML_OP_MUL_MAT:
+                forward_mul_mat(params, op);
+                return true;
+            case LM_GGML_OP_MUL_MAT_ID:
+                forward_mul_mat_id(params, op);
+                return true;
+            default:
+                // LM_GGML_ABORT("fatal error");
+                break;
         }
-
-
-
-
-
-
-
+        return false;
+    }
+
+    void forward_mul_mat(lm_ggml_compute_params * params, lm_ggml_tensor * op) {
+        const lm_ggml_tensor * src0 = op->src[0];
+        const lm_ggml_tensor * src1 = op->src[1];
+        lm_ggml_tensor * dst = op;
+
+        LM_GGML_TENSOR_BINARY_OP_LOCALS
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        LM_GGML_ASSERT(ne0 == ne01);
+        LM_GGML_ASSERT(ne1 == ne11);
+        LM_GGML_ASSERT(ne2 == ne12);
+        LM_GGML_ASSERT(ne3 == ne13);
+
+        // dst cannot be transposed or permuted
+        LM_GGML_ASSERT(nb0 == sizeof(float));
+        LM_GGML_ASSERT(nb0 <= nb1);
+        LM_GGML_ASSERT(nb1 <= nb2);
+        LM_GGML_ASSERT(nb2 <= nb3);
+
+        LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
+
+        LM_GGML_ASSERT(lm_ggml_n_dims(op->src[0]) == 2);
+        // LM_GGML_ASSERT(lm_ggml_n_dims(op->src[1]) == 2);
+
+        char * wdata = static_cast<char *>(params->wdata);
+        const size_t nbw1 = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
+
+        assert(params->wsize >= nbw1 * ne11);
+
+        const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(LM_GGML_TYPE_Q8_0)->from_float;
|
3950
|
+
|
3951
|
+
int64_t i11_processed = 0;
|
3952
|
+
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
3953
|
+
quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
|
3954
|
+
INTER_SIZE);
|
3955
|
+
}
|
3956
|
+
i11_processed = ne11 - ne11 % 4;
|
3957
|
+
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
3958
|
+
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
|
3959
|
+
}
|
3960
|
+
|
3961
|
+
lm_ggml_barrier(params->threadpool);
|
3962
|
+
|
3963
|
+
const void * src1_wdata = params->wdata;
|
3964
|
+
const size_t src1_col_stride = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
|
3965
|
+
int64_t src0_start = (ith * ne01) / nth;
|
3966
|
+
int64_t src0_end = ((ith + 1) * ne01) / nth;
|
3967
|
+
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
|
3968
|
+
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
3969
|
+
if (src0_start >= src0_end) {
|
3970
|
+
return;
|
3971
|
+
}
|
3972
|
+
|
3973
|
+
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
3974
|
+
if (ne11 > 3) {
|
3975
|
+
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
|
3976
|
+
(const char *) src0->data + src0_start * nb01,
|
3977
|
+
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
3978
|
+
}
|
3979
|
+
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
|
3980
|
+
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
3981
|
+
(const char *) src0->data + src0_start * nb01,
|
3982
|
+
(const char *) src1_wdata + (src1_col_stride * iter), 1,
|
3983
|
+
src0_end - src0_start);
|
3798
3984
|
}
|
3799
|
-
} else {
|
3800
|
-
LM_GGML_ABORT("Unsupported type");
|
3801
3985
|
}
|
3802
|
-
}
|
3803
3986
|
|
3804
|
-
|
3987
|
+
void forward_mul_mat_id(lm_ggml_compute_params * params, lm_ggml_tensor * op) {
|
3988
|
+
const lm_ggml_tensor * src0 = op->src[0];
|
3989
|
+
const lm_ggml_tensor * src1 = op->src[1];
|
3990
|
+
const lm_ggml_tensor * ids = op->src[2];
|
3991
|
+
lm_ggml_tensor * dst = op;
|
3992
|
+
|
3993
|
+
LM_GGML_TENSOR_BINARY_OP_LOCALS
|
3994
|
+
|
3995
|
+
const int ith = params->ith;
|
3996
|
+
const int nth = params->nth;
|
3997
|
+
|
3998
|
+
const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(LM_GGML_TYPE_Q8_0)->from_float;
|
3999
|
+
|
4000
|
+
// we don't support permuted src0 or src1
|
4001
|
+
LM_GGML_ASSERT(nb00 == lm_ggml_type_size(src0->type));
|
4002
|
+
LM_GGML_ASSERT(nb10 == lm_ggml_type_size(src1->type));
|
4003
|
+
|
4004
|
+
// dst cannot be transposed or permuted
|
4005
|
+
LM_GGML_ASSERT(nb0 == sizeof(float));
|
4006
|
+
LM_GGML_ASSERT(nb0 <= nb1);
|
4007
|
+
LM_GGML_ASSERT(nb1 <= nb2);
|
4008
|
+
LM_GGML_ASSERT(nb2 <= nb3);
|
4009
|
+
|
4010
|
+
LM_GGML_ASSERT(ne03 == 1);
|
4011
|
+
LM_GGML_ASSERT(ne13 == 1);
|
4012
|
+
LM_GGML_ASSERT(ne3 == 1);
|
4013
|
+
|
4014
|
+
LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
|
4015
|
+
|
4016
|
+
// row groups
|
4017
|
+
const int n_ids = ids->ne[0]; // n_expert_used
|
4018
|
+
const int n_as = ne02; // n_expert
|
4019
|
+
|
4020
|
+
const size_t nbw1 = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
|
4021
|
+
const size_t nbw2 = nbw1*ne11;
|
4022
|
+
const size_t nbw3 = nbw2*ne12;
|
4023
|
+
|
4024
|
+
struct mmid_row_mapping {
|
4025
|
+
int32_t i1;
|
4026
|
+
int32_t i2;
|
4027
|
+
};
|
4028
|
+
|
4029
|
+
LM_GGML_ASSERT(params->wsize >= (LM_GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
|
4030
|
+
n_as * ne12 * sizeof(mmid_row_mapping)));
|
4031
|
+
|
4032
|
+
auto wdata = (char *) params->wdata;
|
4033
|
+
auto wdata_src1_end = (char *) wdata + LM_GGML_PAD(nbw3, sizeof(int64_t));
|
4034
|
+
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
4035
|
+
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
|
4036
|
+
|
4037
|
+
// src1: float32 => block_q8_0
|
4038
|
+
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
4039
|
+
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
|
4040
|
+
from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
|
4041
|
+
(void *) (wdata + i12 * nbw2 + i11 * nbw1),
|
4042
|
+
ne10);
|
4043
|
+
}
|
4044
|
+
}
|
4045
|
+
|
4046
|
+
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
|
4047
|
+
|
4048
|
+
if (ith == 0) {
|
4049
|
+
// initialize matrix_row_counts
|
4050
|
+
memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
|
4051
|
+
|
4052
|
+
// group rows by src0 matrix
|
4053
|
+
for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
|
4054
|
+
for (int32_t id = 0; id < n_ids; ++id) {
|
4055
|
+
const int32_t i02 =
|
4056
|
+
*(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
|
4057
|
+
|
4058
|
+
LM_GGML_ASSERT(i02 >= 0 && i02 < n_as);
|
4059
|
+
|
4060
|
+
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
|
4061
|
+
matrix_row_counts[i02] += 1;
|
4062
|
+
}
|
4063
|
+
}
|
4064
|
+
}
|
4065
|
+
|
4066
|
+
lm_ggml_barrier(params->threadpool);
|
4067
|
+
|
4068
|
+
// compute each matrix multiplication in sequence
|
4069
|
+
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
|
4070
|
+
const int64_t cne1 = matrix_row_counts[cur_a];
|
4071
|
+
|
4072
|
+
if (cne1 == 0) {
|
4073
|
+
continue;
|
4074
|
+
}
|
4075
|
+
|
4076
|
+
auto src0_cur = (const char *) src0->data + cur_a*nb02;
|
4077
|
+
|
4078
|
+
//const int64_t nr0 = ne01; // src0 rows
|
4079
|
+
const int64_t nr1 = cne1; // src1 rows
|
4080
|
+
|
4081
|
+
int64_t src0_cur_start = (ith * ne01) / nth;
|
4082
|
+
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
|
4083
|
+
src0_cur_start =
|
4084
|
+
(src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
|
4085
|
+
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
|
4086
|
+
|
4087
|
+
if (src0_cur_start >= src0_cur_end) return;
|
4088
|
+
|
4089
|
+
for (int ir1 = 0; ir1 < nr1; ir1++) {
|
4090
|
+
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
|
4091
|
+
const int id = row_mapping.i1; // selected expert index
|
4092
|
+
|
4093
|
+
const int64_t i11 = id % ne11;
|
4094
|
+
const int64_t i12 = row_mapping.i2; // row index in src1
|
4095
|
+
|
4096
|
+
const int64_t i1 = id; // selected expert index
|
4097
|
+
const int64_t i2 = i12; // row
|
4098
|
+
|
4099
|
+
auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
|
4100
|
+
|
4101
|
+
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
|
4102
|
+
ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
|
4103
|
+
ne01, src0_cur + src0_cur_start * nb01,
|
4104
|
+
src1_col, 1, src0_cur_end - src0_cur_start);
|
4105
|
+
}
|
4106
|
+
}
|
4107
|
+
#undef MMID_MATRIX_ROW
|
4108
|
+
}
|
4109
|
+
|
4110
|
+
int repack(struct lm_ggml_tensor * t, const void * data, size_t data_size) override {
|
4111
|
+
LM_GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, lm_ggml_type_name(t->type),
|
4112
|
+
(int) NB_COLS, (int) INTER_SIZE);
|
4113
|
+
return ggml::cpu::aarch64::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
|
4114
|
+
}
|
4115
|
+
};
|
4116
|
+
|
4117
|
+
// instance for Q4
|
4118
|
+
static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
|
4119
|
+
static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
|
4120
|
+
static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
|
4121
|
+
|
4122
|
+
// instance for IQ4
|
4123
|
+
static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
|
4124
|
+
|
4125
|
+
} // namespace ggml::cpu::aarch64
|
4126
|
+
|
4127
|
+
static const ggml::cpu::tensor_traits * lm_ggml_aarch64_get_optimal_repack_type(const struct lm_ggml_tensor * cur) {
|
3805
4128
|
if (cur->type == LM_GGML_TYPE_Q4_0) {
|
3806
|
-
|
3807
|
-
|
3808
|
-
|
4129
|
+
if (lm_ggml_cpu_has_avx2() || (lm_ggml_cpu_has_sve() && lm_ggml_cpu_has_matmul_int8() && lm_ggml_cpu_get_sve_cnt() == QK8_0)) {
|
4130
|
+
if (cur->ne[1] % 8 == 0) {
|
4131
|
+
return &ggml::cpu::aarch64::q4_0_8x8_q8_0;
|
4132
|
+
}
|
3809
4133
|
}
|
3810
4134
|
if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) {
|
3811
|
-
|
4135
|
+
if (cur->ne[1] % 4 == 0) {
|
4136
|
+
return &ggml::cpu::aarch64::q4_0_4x8_q8_0;
|
4137
|
+
}
|
3812
4138
|
}
|
3813
4139
|
if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
|
3814
|
-
|
4140
|
+
if (cur->ne[1] % 4 == 0) {
|
4141
|
+
return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
|
4142
|
+
}
|
3815
4143
|
}
|
3816
4144
|
} else if (cur->type == LM_GGML_TYPE_IQ4_NL) {
|
3817
4145
|
if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
|
3818
|
-
|
4146
|
+
if (cur->ne[1] % 4 == 0) {
|
4147
|
+
return &ggml::cpu::aarch64::iq4_nl_4x4_q8_0;
|
4148
|
+
}
|
4149
|
+
}
|
4150
|
+
}
|
4151
|
+
|
4152
|
+
return nullptr;
|
4153
|
+
}
|
4154
|
+
|
4155
|
+
static void lm_ggml_backend_cpu_aarch64_buffer_init_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor) {
|
4156
|
+
tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(lm_ggml_aarch64_get_optimal_repack_type(tensor));
|
4157
|
+
|
4158
|
+
LM_GGML_UNUSED(buffer);
|
4159
|
+
}
|
4160
|
+
|
4161
|
+
static void lm_ggml_backend_cpu_aarch64_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor,
|
4162
|
+
const void * data, size_t offset, size_t size) {
|
4163
|
+
LM_GGML_ASSERT(offset == 0);
|
4164
|
+
LM_GGML_ASSERT(size == lm_ggml_nbytes(tensor));
|
4165
|
+
|
4166
|
+
auto tensor_traits = (ggml::cpu::aarch64::tensor_traits_base *) tensor->extra;
|
4167
|
+
auto OK = tensor_traits->repack(tensor, data, size);
|
4168
|
+
|
4169
|
+
LM_GGML_ASSERT(OK == 0);
|
4170
|
+
LM_GGML_UNUSED(buffer);
|
4171
|
+
}
|
4172
|
+
|
4173
|
+
static const char * lm_ggml_backend_cpu_aarch64_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) {
|
4174
|
+
return "CPU_AARCH64";
|
4175
|
+
|
4176
|
+
LM_GGML_UNUSED(buft);
|
4177
|
+
}
|
4178
|
+
|
4179
|
+
static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) {
|
4180
|
+
lm_ggml_backend_buffer_t buffer = lm_ggml_backend_buft_alloc_buffer(lm_ggml_backend_cpu_buffer_type(), size);
|
4181
|
+
|
4182
|
+
if (buffer == nullptr) {
|
4183
|
+
return nullptr;
|
4184
|
+
}
|
4185
|
+
|
4186
|
+
buffer->buft = buft;
|
4187
|
+
buffer->iface.init_tensor = lm_ggml_backend_cpu_aarch64_buffer_init_tensor;
|
4188
|
+
buffer->iface.set_tensor = lm_ggml_backend_cpu_aarch64_buffer_set_tensor;
|
4189
|
+
return buffer;
|
4190
|
+
}
|
4191
|
+
|
4192
|
+
static size_t lm_ggml_backend_cpu_aarch64_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) {
|
4193
|
+
return TENSOR_ALIGNMENT;
|
4194
|
+
|
4195
|
+
LM_GGML_UNUSED(buft);
|
4196
|
+
}
|
4197
|
+
|
4198
|
+
namespace ggml::cpu::aarch64 {
|
4199
|
+
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
4200
|
+
bool supports_op(lm_ggml_backend_dev_t, const struct lm_ggml_tensor * op) override {
|
4201
|
+
if ( op->op == LM_GGML_OP_MUL_MAT &&
|
4202
|
+
op->src[0]->buffer &&
|
4203
|
+
(lm_ggml_n_dims(op->src[0]) == 2) &&
|
4204
|
+
op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type() &&
|
4205
|
+
lm_ggml_aarch64_get_optimal_repack_type(op->src[0])
|
4206
|
+
) {
|
4207
|
+
if (op->src[1]->buffer && !lm_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
4208
|
+
return false;
|
4209
|
+
}
|
4210
|
+
if (op->src[1]->type == LM_GGML_TYPE_F32) {
|
4211
|
+
return true;
|
4212
|
+
}
|
4213
|
+
//if (op->src[1]->type == LM_GGML_TYPE_Q8_0) {
|
4214
|
+
// return true;
|
4215
|
+
//}
|
4216
|
+
// may be possible if Q8_0 packed...
|
4217
|
+
} else if (op->op == LM_GGML_OP_MUL_MAT_ID
|
4218
|
+
&& op->src[0]->buffer
|
4219
|
+
&& (lm_ggml_n_dims(op->src[0]) == 3)
|
4220
|
+
&& op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type()
|
4221
|
+
&& lm_ggml_aarch64_get_optimal_repack_type(op->src[0])
|
4222
|
+
) {
|
4223
|
+
if (op->src[1]->buffer && !lm_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
4224
|
+
return false;
|
4225
|
+
}
|
4226
|
+
if (op->src[1]->type == LM_GGML_TYPE_F32) {
|
4227
|
+
return true;
|
4228
|
+
}
|
4229
|
+
//if (op->src[1]->type == LM_GGML_TYPE_Q8_0) {
|
4230
|
+
// return true;
|
4231
|
+
//}
|
3819
4232
|
}
|
4233
|
+
return false;
|
3820
4234
|
}
|
3821
4235
|
|
3822
|
-
|
4236
|
+
ggml::cpu::tensor_traits * get_tensor_traits(const struct lm_ggml_tensor * op) override {
|
4237
|
+
if (op->op == LM_GGML_OP_MUL_MAT || op->op == LM_GGML_OP_MUL_MAT_ID) {
|
4238
|
+
if (op->src[0]->buffer && op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type()) {
|
4239
|
+
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
4240
|
+
}
|
4241
|
+
}
|
4242
|
+
return nullptr;
|
4243
|
+
}
|
4244
|
+
};
|
4245
|
+
} // namespace ggml::cpu::aarch64
|
4246
|
+
|
4247
|
+
lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_aarch64_buffer_type(void) {
|
4248
|
+
static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type_aarch64 = {
|
4249
|
+
/* .iface = */ {
|
4250
|
+
/* .get_name = */ lm_ggml_backend_cpu_aarch64_buffer_type_get_name,
|
4251
|
+
/* .alloc_buffer = */ lm_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
|
4252
|
+
/* .get_alignment = */ lm_ggml_backend_cpu_aarch64_buffer_type_get_alignment,
|
4253
|
+
/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
|
4254
|
+
/* .get_alloc_size = */ nullptr, // defaults to lm_ggml_nbytes
|
4255
|
+
/* .is_host = */ nullptr,
|
4256
|
+
},
|
4257
|
+
/* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0),
|
4258
|
+
/* .context = */ new ggml::cpu::aarch64::extra_buffer_type(),
|
4259
|
+
};
|
4260
|
+
|
4261
|
+
return &lm_ggml_backend_cpu_buffer_type_aarch64;
|
3823
4262
|
}
|
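Taken together, the added code exposes the repacked kernels through a dedicated CPU_AARCH64 buffer type: init_tensor stores the optimal tensor_traits instance in tensor->extra, set_tensor repacks the raw quantized data as it is uploaded, and extra_buffer_type routes eligible MUL_MAT and MUL_MAT_ID ops to forward_mul_mat and forward_mul_mat_id. The sketch below is not part of the package; it assumes the lm_-prefixed ggml context and allocation helpers bundled with it (lm_ggml_init, lm_ggml_backend_alloc_ctx_tensors_from_buft, lm_ggml_backend_tensor_set), assumes a CPU for which one of the repacked layouts applies, uses zero-filled stand-in bytes for the Q4_0 weight data, and the exact header that declares the aarch64 buffer type may differ from the names used here.

// Hypothetical host code, illustration only: place a 2-D Q4_0 weight in the new
// CPU_AARCH64 buffer type so it is repacked to the interleaved layout on upload.
#include <cstdint>
#include <vector>
#include "ggml.h"          // header names as bundled under package/cpp
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cpu.h"

int main() {
    // one 4096x4096 Q4_0 weight; ne[1] is a multiple of 8, so the 8x8 repack can apply
    struct lm_ggml_init_params ip = { /*.mem_size =*/ lm_ggml_tensor_overhead(), /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true };
    struct lm_ggml_context * ctx = lm_ggml_init(ip);
    struct lm_ggml_tensor  * w   = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_Q4_0, 4096, 4096);

    // init_tensor stores the optimal traits (for example q4_0_8x8_q8_0) in w->extra,
    // or nullptr if no repacked layout applies on this CPU
    lm_ggml_backend_buffer_t buf = lm_ggml_backend_alloc_ctx_tensors_from_buft(ctx, lm_ggml_backend_cpu_aarch64_buffer_type());

    // set_tensor goes through tensor_traits::repack() instead of a plain copy;
    // all-zero bytes form valid (zero-valued) Q4_0 blocks, used here only as stand-in data
    std::vector<uint8_t> q4_0_data(lm_ggml_nbytes(w));
    lm_ggml_backend_tensor_set(w, q4_0_data.data(), 0, q4_0_data.size());

    lm_ggml_backend_buffer_free(buf);
    lm_ggml_free(ctx);
    return 0;
}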